    def testWrapClass(self):
        # Normally a mirrored value would be the same across devices, but
        # for a test it is convenient to be able to tell the values apart.
        result = distribute_utils.regroup(
            (_nested_value("1"), _nested_value("2")), values.Mirrored)
        self.assertIsInstance(result, tuple)
        self.assertLen(result, 3)
        self._is_per_replica(result[0], ["a1", "a2"], values.Mirrored)
        self._is_per_replica(result[2], ["h1", "h2"], values.Mirrored)

        self.assertIsInstance(result[1], list)
        self.assertLen(result[1], 3)
        self._is_per_replica(result[1][0], ["b1", "b2"], values.Mirrored)
        self._is_per_replica(result[1][2], ["g1", "g2"], values.Mirrored)

        self.assertIsInstance(result[1][1], dict)
        self.assertEqual(set(["c", "e"]), set(result[1][1].keys()))
        self._is_per_replica(result[1][1]["c"], ["d1", "d2"], values.Mirrored)
        self._is_per_replica(result[1][1]["e"], ["f1", "f2"], values.Mirrored)

        # Also test that we can undo the merge using select_replica()
        self.assertEqual(_nested_value("1"),
                         distribute_utils.select_replica(0, result))
        self.assertEqual(_nested_value("2"),
                         distribute_utils.select_replica(1, result))
        # Values are marked as mirrored, so select_replica_mirrored() is allowed.
        self.assertEqual(_nested_value("1"),
                         distribute_utils.select_replica_mirrored(0, result))
        self.assertEqual(_nested_value("2"),
                         distribute_utils.select_replica_mirrored(1, result))

    def testNested(self):
        result = distribute_utils.regroup(
            (_nested_value("1"), _nested_value("2")))
        self.assertIsInstance(result, tuple)
        self.assertLen(result, 3)
        self._is_per_replica(result[0], ["a1", "a2"])
        self._is_per_replica(result[2], ["h1", "h2"])

        self.assertIsInstance(result[1], list)
        self.assertLen(result[1], 3)
        self._is_per_replica(result[1][0], ["b1", "b2"])
        self._is_per_replica(result[1][2], ["g1", "g2"])

        self.assertIsInstance(result[1][1], dict)
        self.assertEqual(set(["c", "e"]), set(result[1][1].keys()))
        self._is_per_replica(result[1][1]["c"], ["d1", "d2"])
        self._is_per_replica(result[1][1]["e"], ["f1", "f2"])

        # Also test that we can undo the merge using select_replica()
        self.assertEqual(_nested_value("1"),
                         distribute_utils.select_replica(0, result))
        self.assertEqual(_nested_value("2"),
                         distribute_utils.select_replica(1, result))
        # select_replica_mirrored() should fail due to non-mirrored values.
        with self.assertRaises(TypeError):
            distribute_utils.select_replica_mirrored(0, result)
        with self.assertRaises(TypeError):
            distribute_utils.select_replica_mirrored(1, result)
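For reference, here is a minimal pure-Python sketch of the round trip these two tests verify. PerReplicaToy, toy_regroup, and toy_select_replica are our stand-ins, not TensorFlow internals; the real regroup() wraps merged leaves in PerReplica/Mirrored values.

class PerReplicaToy:
    """Stands in for the PerReplica/Mirrored wrapper produced by regroup()."""
    def __init__(self, *vals):
        self.values = vals

def toy_regroup(per_replica_structures):
    first = per_replica_structures[0]
    if isinstance(first, (list, tuple)):
        merged = [toy_regroup(tuple(s[i] for s in per_replica_structures))
                  for i in range(len(first))]
        return type(first)(merged) if isinstance(first, tuple) else merged
    if isinstance(first, dict):
        return {k: toy_regroup(tuple(s[k] for s in per_replica_structures))
                for k in first}
    return PerReplicaToy(*per_replica_structures)  # leaf: one value per replica

def toy_select_replica(replica_id, structure):
    if isinstance(structure, PerReplicaToy):
        return structure.values[replica_id]
    if isinstance(structure, (list, tuple)):
        out = [toy_select_replica(replica_id, v) for v in structure]
        return type(structure)(out) if isinstance(structure, tuple) else out
    if isinstance(structure, dict):
        return {k: toy_select_replica(replica_id, v) for k, v in structure.items()}
    return structure  # non-wrapped values pass through unchanged

replica_0 = ("a1", {"c": "d1"}, "h1")
replica_1 = ("a2", {"c": "d2"}, "h2")
merged = toy_regroup((replica_0, replica_1))
assert toy_select_replica(0, merged) == replica_0
assert toy_select_replica(1, merged) == replica_1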
Example #3
    def _test_input_fn_iterator(self,
                                task_type,
                                task_id,
                                num_gpus,
                                input_fn,
                                expected_values,
                                test_reinitialize=True,
                                ignore_order=False):
        distribution, master_target, config = self._get_test_object(
            task_type, task_id, num_gpus)
        devices = distribution.extended.worker_devices

        with ops.Graph().as_default(), \
             self.cached_session(config=config,
                                 target=master_target) as sess:
            iterator = distribution.make_input_fn_iterator(input_fn)
            sess.run(iterator.initializer)

            for expected_value in expected_values:
                next_element = iterator.get_next()
                computed_value = sess.run([
                    distribute_utils.select_replica(r, next_element)
                    for r in range(len(devices))
                ])
                if ignore_order:
                    self.assertCountEqual(list(expected_value),
                                          list(computed_value))
                else:
                    self.assertEqual(list(expected_value),
                                     list(computed_value))

            # error raised by calling optional_get_value on an Optional of None
            with self.assertRaises(errors.InvalidArgumentError):
                next_element = iterator.get_next()
                sess.run([
                    distribute_utils.select_replica(r, next_element)
                    for r in range(len(devices))
                ])

            # After re-initializing the iterator, should be able to iterate again.
            if test_reinitialize:
                sess.run(iterator.initializer)

                for expected_value in expected_values:
                    next_element = iterator.get_next()
                    computed_value = sess.run([
                        distribute_utils.select_replica(r, next_element)
                        for r in range(len(devices))
                    ])
                    if ignore_order:
                        self.assertCountEqual(list(expected_value),
                                              list(computed_value))
                    else:
                        self.assertEqual(list(expected_value),
                                         list(computed_value))
Example #4
 def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
   assert isinstance(colocate_with, tuple)
   # TODO(josh11b): In eager mode, use one thread per device.
   updates = []
   for i, d in enumerate(colocate_with):
     name = "update_%d" % i
     with ops.device(d), distribute_lib.UpdateContext(i), ops.name_scope(name):
       updates.append(
           fn(*distribute_utils.select_replica(i, args),
              **distribute_utils.select_replica(i, kwargs)))
   return distribute_utils.update_regroup(self, updates, group)
Example #5
 def _update(self, var, fn, args, kwargs, group):
     # TODO(josh11b): In eager mode, use one thread per device.
     assert isinstance(var, values.DistributedVariable)
     updates = []
     for i, v in enumerate(var.values):
         name = "update_%d" % i
         with ops.device(v.device), \
              distribute_lib.UpdateContext(i), \
              ops.name_scope(name):
             # If args and kwargs are not mirrored, the value is returned as is.
             updates.append(
                 fn(v, *distribute_utils.select_replica(i, args),
                    **distribute_utils.select_replica(i, kwargs)))
     return distribute_utils.update_regroup(self, updates, group)
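A stripped-down version of the per-device update loop in the two helpers above, using plain Python lists in place of distributed variables and devices; toy_update and its arguments are ours.

# toy_update mirrors the structure of _update(): one call to fn per device,
# pairing each device's copy of the variable with that replica's slice of the
# arguments.
def toy_update(per_device_vars, fn, per_replica_args):
    updates = []
    for i, v in enumerate(per_device_vars):
        updates.append(fn(v, *(arg[i] for arg in per_replica_args)))
    return updates  # the real code regroups these into one distributed value

deltas = ([0.5, 1.5],)  # one per-replica positional argument (one value per device)
assert toy_update([10.0, 20.0], lambda v, d: v - d, deltas) == [9.5, 18.5]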
Example #6
        def rewrite_fn(*args):
            """The rewritten step fn running on TPU."""
            del args

            per_replica_inputs = multi_worker_iterator.get_next()
            replicate_inputs = []
            for replica_id in range(self._num_replicas_in_sync):
                select_replica = lambda x: distribute_utils.select_replica(  # pylint: disable=g-long-lambda
                    replica_id, x)  # pylint: disable=cell-var-from-loop
                replicate_inputs.append(
                    (nest.map_structure(select_replica, per_replica_inputs), ))

            replicate_outputs = tpu.replicate(
                run_fn,
                replicate_inputs,
                device_assignment=self._device_assignment,
                xla_options=tpu.XLAOptions(
                    use_spmd_for_xla_partitioning=False))

            # If run_fn has tensor outputs, tpu.replicate returns a list of
            # lists, which we flatten here. If run_fn has no tensor outputs,
            # tpu.replicate returns a list of no-ops, and we keep the output
            # as is.
            if isinstance(replicate_outputs[0], list):
                replicate_outputs = nest.flatten(replicate_outputs)

            return replicate_outputs
Example #7
    def _call_for_each_replica(self, fn, args, kwargs):
        # For now, `fn` must be an @tf.function.
        # TODO(josh11b): Relax this restriction?  Main problem is if
        # (a) executing eagerly, (b) `fn` not @tf.function, and
        # (c) executed frequently.
        assert isinstance(fn, def_function.Function)

        if _outside_run_graph() is not None:
            # Nested case, should just use outer function's context for things like
            # the current replica index.
            # TODO(josh11b): Test this case!
            with MirroredFunctionReplicaContext(self._container_strategy()):
                results = fn(*nest.map_structure(_unwrap_tensors, args),
                             **nest.map_structure(_unwrap_tensors, kwargs))
                return nest.map_structure(_wrap_tensors, results)

        _replica_index.graph_outside_run = ops.get_default_graph()
        return_values = []

        try:
            with MirroredFunctionReplicaContext(self._container_strategy()):
                for index, device in enumerate(self._devices):
                    _replica_index.current = index
                    with ops.device(device):
                        if context.executing_eagerly():
                            # NOTE: These functions need to execute concurrently if they
                            # use a collective op. This is a particular concern with eager
                            # execution.
                            with context.execution_mode(context.ASYNC):
                                return_values.append(
                                    fn(
                                        *distribute_utils.select_replica(
                                            index, args),
                                        **distribute_utils.select_replica(
                                            index, kwargs)))
                        else:
                            return_values.append(
                                fn(
                                    *distribute_utils.select_replica(
                                        index, args),
                                    **distribute_utils.select_replica(
                                        index, kwargs)))
        finally:
            _replica_index.graph_outside_run = None
            _replica_index.current = None

        return distribute_utils.regroup(return_values)

    def testSameId(self):
        foo = object()
        result = distribute_utils.regroup((("a", foo), ("b", foo)))
        self.assertIsInstance(result, tuple)
        self.assertLen(result, 2)
        self._is_per_replica(result[0], ["a", "b"])
        self.assertIs(foo, result[1])

        # Test select_replica(), should undo the merge done by regroup().
        result_0 = distribute_utils.select_replica(0, result)
        self.assertIsInstance(result_0, tuple)
        self.assertLen(result_0, 2)
        self.assertEqual("a", result_0[0])
        self.assertIs(foo, result_0[1])
        result_1 = distribute_utils.select_replica(1, result)
        self.assertIsInstance(result_1, tuple)
        self.assertLen(result_1, 2)
        self.assertEqual("b", result_1[0])
        self.assertIs(foo, result_1[1])
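A toy illustration of the behaviour testSameId checks: when every replica holds the very same object, the merged result keeps that object by identity rather than wrapping it in a per-replica container. toy_regroup_leaf is ours, not the real regroup().

def toy_regroup_leaf(per_replica_values):
    first = per_replica_values[0]
    if all(v is first for v in per_replica_values):
        return first                      # identical object on every replica
    return tuple(per_replica_values)      # otherwise keep one value per replica

foo = object()
assert toy_regroup_leaf(("a", "b")) == ("a", "b")
assert toy_regroup_leaf((foo, foo)) is foo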
Example #9
      def test_get_next_as_optional(iterator):
        for expected_value in expected_values:
          next_element = iterator.get_next_as_optional()
          computed_value = evaluate([
              distribute_utils.select_replica(r, next_element.get_value())
              for r in range(len(devices))
          ])

          self.assertEqual(len(expected_value), len(computed_value))
          for i in range(len(expected_value)):
            self.assertAllEqual(expected_value[i], computed_value[i])

        next_element = iterator.get_next_as_optional()
        self.assertFalse(self.evaluate(next_element.has_value()))
        with self.assertRaises(errors.InvalidArgumentError):
          evaluate([
              distribute_utils.select_replica(r, next_element.get_value())
              for r in range(len(devices))
          ])
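A single-device analogue of the get_next_as_optional pattern above: the iterator hands back a tf.experimental.Optional whose has_value() turns False once the dataset is exhausted, instead of raising. The dataset is illustrative.

import tensorflow as tf

dataset = tf.data.Dataset.range(2)
iterator = iter(dataset)

opt = iterator.get_next_as_optional()
assert bool(opt.has_value())           # a value is present
assert int(opt.get_value()) == 0

iterator.get_next_as_optional()        # consumes the remaining element
opt = iterator.get_next_as_optional()  # past the end
assert not bool(opt.has_value())       # no value; get_value() would now fail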
Example #10
    def _test_input_fn_iterator(self,
                                iterator,
                                devices,
                                expected_values,
                                sess=None,
                                test_reinitialize=True,
                                ignore_order=False):
        evaluate = lambda x: sess.run(x) if sess else self.evaluate(x)
        evaluate(iterator.initializer)

        for expected_value in expected_values:
            next_element = iterator.get_next()
            computed_value = evaluate([
                distribute_utils.select_replica(r, next_element)
                for r in range(len(devices))
            ])
            if ignore_order:
                self.assertCountEqual(expected_value, computed_value)
            else:
                self.assertEqual(expected_value, computed_value)

        with self.assertRaises(errors.OutOfRangeError):
            next_element = iterator.get_next()
            evaluate([
                distribute_utils.select_replica(r, next_element)
                for r in range(len(devices))
            ])

        # After re-initializing the iterator, should be able to iterate again.
        if test_reinitialize:
            evaluate(iterator.initializer)

            for expected_value in expected_values:
                next_element = iterator.get_next()
                computed_value = evaluate([
                    distribute_utils.select_replica(r, next_element)
                    for r in range(len(devices))
                ])
                if ignore_order:
                    self.assertCountEqual(expected_value, computed_value)
                else:
                    self.assertEqual(expected_value, computed_value)

    def testNamedTuple(self):

        # We include toy implementations of Scaffold and EstimatorSpec to
        # avoid a dependency on Estimator here.

        class Scaffold(object):
            pass

        class EstimatorSpec(
                collections.namedtuple(
                    "EstimatorSpec",
                    ["mode", "loss", "train_op", "scaffold"])):
            def __new__(cls, mode, loss, train_op, scaffold=None):
                return super(EstimatorSpec, cls).__new__(
                    cls,
                    mode=mode,
                    loss=loss,
                    train_op=train_op,
                    scaffold=scaffold or Scaffold())

        with context.graph_mode(), ops.Graph().as_default():
            created_estimator_specs = []

            for device_id in range(3):
                spec = EstimatorSpec(mode=mode_keys.EstimatorModeKeys.TRAIN,
                                     loss=constant_op.constant(device_id / 2),
                                     train_op=array_ops.identity(
                                         constant_op.constant(device_id)))
                created_estimator_specs.append(spec)

            merged_estimator_spec = distribute_utils.regroup(
                created_estimator_specs)

            self.assertIsInstance(merged_estimator_spec, EstimatorSpec)
            self.assertEqual(mode_keys.EstimatorModeKeys.TRAIN,
                             merged_estimator_spec.mode)
            for device_id in range(3):
                self.assertEqual(created_estimator_specs[device_id].loss,
                                 merged_estimator_spec.loss.values[device_id])
                self.assertEqual(
                    created_estimator_specs[device_id].train_op,
                    merged_estimator_spec.train_op.values[device_id])
                # Scaffold is populated by `EstimatorSpec.__new__`.
                self.assertEqual(
                    created_estimator_specs[device_id].scaffold,
                    merged_estimator_spec.scaffold.values[device_id])
                self.assertIsInstance(
                    created_estimator_specs[device_id].scaffold, Scaffold)
                # Also test that we can undo the merge using select_replica()
                self.assertEqual(
                    created_estimator_specs[device_id],
                    distribute_utils.select_replica(device_id,
                                                    merged_estimator_spec))
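A minimal sketch of the namedtuple handling this test exercises: fields are merged one by one and the result is rebuilt with the same namedtuple class. Point and toy_regroup_namedtuple are ours, not part of distribute_utils.

import collections

Point = collections.namedtuple("Point", ["x", "y"])

def toy_regroup_namedtuple(per_replica_specs):
    cls = type(per_replica_specs[0])
    merged_fields = [tuple(getattr(spec, field) for spec in per_replica_specs)
                     for field in cls._fields]
    return cls(*merged_fields)

merged = toy_regroup_namedtuple([Point(x=1, y="a"), Point(x=2, y="b")])
assert isinstance(merged, Point)
assert merged.x == (1, 2) and merged.y == ("a", "b")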
Example #12
      def test_get_next(iterator):
        for expected_value in expected_values:
          next_element = iterator.get_next()
          computed_value = evaluate([
              distribute_utils.select_replica(r, next_element)
              for r in range(len(devices))
          ])

          self.assertEqual(len(expected_value), len(computed_value))
          for i in range(len(expected_value)):
            self.assertAllEqual(expected_value[i], computed_value[i])

        with self.assertRaises(errors.OutOfRangeError):
          next_element = iterator.get_next()
          evaluate([
              distribute_utils.select_replica(r, next_element)
              for r in range(len(devices))
          ])

        # After re-initializing the iterator, should be able to iterate again.
        if not ops.executing_eagerly_outside_functions():
          evaluate(control_flow_ops.group(iterator.initializer))
        else:
          if api_type == "wrap_into_iterator":
            self.skipTest("unsupported test combination")
          else:
            iterator = iter(dataset)

        for expected_value in expected_values:
          next_element = iterator.get_next()
          computed_value = evaluate([
              distribute_utils.select_replica(r, next_element)
              for r in range(len(devices))
          ])
          self.assertEqual(len(expected_value), len(computed_value))
          for i in range(len(expected_value)):
            self.assertAllEqual(expected_value[i], computed_value[i])
Example #13
        def tpu_function(args, kwargs):
            """TF Function used to replicate the user computation."""
            if kwargs is None:
                kwargs = {}

            # Remove trailing `None`s from args, as they are not replicatable.
            # If there are `None`s in the middle we can't do anything about
            # them, so let those cases fail. For example, Keras model predict
            # passes the targets as None; we handle it here so client libraries
            # don't have to, since other strategies can handle None values
            # better.
            while args and args[-1] is None:
                args = args[:-1]

            # Used to re-structure flattened output tensors from `tpu.replicate()`
            # into a structured format.
            result = [[]]

            def replicated_fn(replica_id, replica_args, replica_kwargs):
                """Wraps user function to provide replica ID and `Tensor` inputs."""
                with _TPUReplicaContext(strategy,
                                        replica_id_in_sync_group=replica_id):
                    result[0] = fn(*replica_args, **replica_kwargs)
                return result[0]

            replicate_inputs = []  # By replica.
            for i in range(strategy.num_replicas_in_sync):
                replicate_inputs.append([
                    constant_op.constant(i, dtype=dtypes.int32),
                    distribute_utils.select_replica(i, args),
                    distribute_utils.select_replica(i, kwargs)
                ])

            # Construct and pass `maximum_shapes` so that we could support dynamic
            # shapes using dynamic padder.
            if options.experimental_enable_dynamic_batch_size and replicate_inputs:
                maximum_shapes = []
                flattened_list = nest.flatten(replicate_inputs[0])
                for input_tensor in flattened_list:
                    if tensor_util.is_tensor(input_tensor):
                        rank = input_tensor.shape.rank
                    else:
                        rank = np.ndim(input_tensor)
                    maximum_shape = tensor_shape.TensorShape([None] * rank)
                    maximum_shapes.append(maximum_shape)
                maximum_shapes = nest.pack_sequence_as(replicate_inputs[0],
                                                       maximum_shapes)
            else:
                maximum_shapes = None

            if options.experimental_bucketizing_dynamic_shape:
                padding_spec = tpu.PaddingSpec.POWER_OF_TWO
            else:
                padding_spec = None

            with strategy.scope():
                replicate_outputs = tpu.replicate(
                    replicated_fn,
                    replicate_inputs,
                    device_assignment=self._device_assignment,
                    maximum_shapes=maximum_shapes,
                    padding_spec=padding_spec,
                    xla_options=tpu.XLAOptions(
                        use_spmd_for_xla_partitioning=False))

            # Remove all no-ops that may have been added during `tpu.replicate()`.
            if isinstance(result[0], list):
                result[0] = [
                    output for output in result[0]
                    if not isinstance(output, ops.Operation)
                ]

            # Workaround for `tpu.replicate` behaviour when a single `Tensor` is returned.
            if result[0] is None or isinstance(result[0], ops.Operation):
                replicate_outputs = [None] * len(replicate_outputs)
            else:
                replicate_outputs = [
                    nest.pack_sequence_as(result[0],
                                          nest.flatten(replica_output))
                    for replica_output in replicate_outputs
                ]
            return distribute_utils.regroup(replicate_outputs)
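A sketch of the restructuring step at the end of tpu_function: each replica's flat list of output tensors is packed back into the structure the user fn returned (result[0]). It uses the public tf.nest counterpart of the internal nest module; the structure and values below are made up for illustration.

import tensorflow as tf

structure = {"loss": 0.0, "metrics": (0.0, 0.0)}   # shape of what fn returned
flat_replica_output = [1.0, 2.0, 3.0]              # flattened outputs of one replica

packed = tf.nest.pack_sequence_as(structure, flat_replica_output)
assert packed == {"loss": 1.0, "metrics": (2.0, 3.0)}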
Example #14
  def _test_input_iteration(self,
                            input_type,
                            api_type,
                            iteration_type,
                            dataset_or_input_fn,
                            worker_device_pairs,
                            expected_values,
                            strategy,
                            sess=None,
                            split_batch_by=None,
                            input_context=None):
    if iteration_type == "for_loop" and not context.executing_eagerly():
      self.skipTest("unsupported test combination.")

    if api_type == "wrap_into_iterator" and iteration_type == "for_loop":
      self.skipTest("unsupported test combination.")

    if api_type == "wrap_into_iterator" and input_type == "input_fn":
      self.skipTest("unsupported test combination.")

    devices = nest.flatten([ds for _, ds in worker_device_pairs])
    input_workers = input_lib.InputWorkers(worker_device_pairs)

    if api_type == "wrap_into_iterator":
      iterator = self._wrap_iterator(
          input_type,
          dataset_or_input_fn,
          input_workers,
          devices,
          split_batch_by,
          strategy,
          input_context=input_context)
    else:
      # wrapping into a dataset:
      dataset = self._wrap_dataset(
          input_type,
          dataset_or_input_fn,
          input_workers,
          split_batch_by,
          strategy,
          input_context=input_context)

      if ops.executing_eagerly_outside_functions():
        iterator = iter(dataset)
      else:
        if isinstance(dataset, input_lib.DistributedDatasetV1):
          iterator = dataset.make_initializable_iterator()
        else:
          self.skipTest("unsupported test combination")

    if isinstance(iterator, composite_tensor.CompositeTensor):
      nest.assert_same_structure(iterator, iterator._type_spec,
                                 expand_composites=True)

    if iteration_type == "get_next":
      evaluate = lambda x: sess.run(x) if sess else self.evaluate(x)
      if not ops.executing_eagerly_outside_functions():
        evaluate(control_flow_ops.group(iterator.initializer))

      def test_get_next(iterator):
        for expected_value in expected_values:
          next_element = iterator.get_next()
          computed_value = evaluate([
              distribute_utils.select_replica(r, next_element)
              for r in range(len(devices))
          ])

          self.assertEqual(len(expected_value), len(computed_value))
          for i in range(len(expected_value)):
            self.assertAllEqual(expected_value[i], computed_value[i])

        with self.assertRaises(errors.OutOfRangeError):
          next_element = iterator.get_next()
          evaluate([
              distribute_utils.select_replica(r, next_element)
              for r in range(len(devices))
          ])

        # After re-initializing the iterator, should be able to iterate again.
        if not ops.executing_eagerly_outside_functions():
          evaluate(control_flow_ops.group(iterator.initializer))
        else:
          if api_type == "wrap_into_iterator":
            self.skipTest("unsupported test combination")
          else:
            iterator = iter(dataset)

        for expected_value in expected_values:
          next_element = iterator.get_next()
          computed_value = evaluate([
              distribute_utils.select_replica(r, next_element)
              for r in range(len(devices))
          ])
          self.assertEqual(len(expected_value), len(computed_value))
          for i in range(len(expected_value)):
            self.assertAllEqual(expected_value[i], computed_value[i])

      def test_get_next_as_optional(iterator):
        for expected_value in expected_values:
          next_element = iterator.get_next_as_optional()
          computed_value = evaluate([
              distribute_utils.select_replica(r, next_element.get_value())
              for r in range(len(devices))
          ])

          self.assertEqual(len(expected_value), len(computed_value))
          for i in range(len(expected_value)):
            self.assertAllEqual(expected_value[i], computed_value[i])

        next_element = iterator.get_next_as_optional()
        self.assertFalse(self.evaluate(next_element.has_value()))
        with self.assertRaises(errors.InvalidArgumentError):
          evaluate([
              distribute_utils.select_replica(r, next_element.get_value())
              for r in range(len(devices))
          ])

      test_get_next(iterator)

      # re-initializing the iterator
      if not tf2.enabled():
        self.skipTest("Not testing get_next_as_optional in TF1")
      else:
        if api_type == "wrap_into_iterator":
          self.skipTest("unsupported test combination")
        else:
          iterator = iter(dataset)

      test_get_next_as_optional(iterator)

    if iteration_type == "for_loop" and context.executing_eagerly():
      actual_values = []
      for x in dataset:
        computed_value = self.evaluate(
            [distribute_utils.select_replica(r, x)
             for r in range(len(devices))])
        actual_values.append(computed_value)
      for i, expected_value in enumerate(expected_values):
        self.assertEqual(len(expected_value), len(actual_values[i]))
        for j in range(len(expected_value)):
          self.assertAllEqual(expected_value[j], actual_values[i][j])

    def testOneDevice(self):
        result = distribute_utils.regroup((_nested_value("1"), ))
        # On one device regroup() and select_replica() are basically identity.
        self.assertEqual(_nested_value("1"), result)
        self.assertEqual(_nested_value("1"),
                         distribute_utils.select_replica(0, result))
Example #16
def _call_for_each_replica(distribution, fn, args, kwargs):
    """Run `fn` in separate threads, once per replica/worker device.

  Args:
    distribution: the DistributionStrategy object.
    fn: function to run (will be run once per replica, each in its own thread).
    args: positional arguments for `fn`
    kwargs: keyword arguments for `fn`.

  Returns:
    Merged return value of `fn` across all replicas.

  Raises:
    RuntimeError: If fn() calls get_replica_context().merge_call() a different
        number of times from the available devices.
  """
    # TODO(josh11b): Add this option once we add synchronization to variable
    # creation. Until then, this is pretty unsafe to use.
    run_concurrently = False
    if not context.executing_eagerly():
        # Needed for per-thread device, etc. contexts in graph mode.
        ops.get_default_graph().switch_to_thread_local()

    coord = coordinator.Coordinator(
        clean_stop_exception_types=(_RequestedStop, ))

    shared_variable_store = {}
    devices = distribution.extended.worker_devices

    thread_local_callables = _get_thread_local_configuration_callable()

    # TODO(isaprykin): Create these threads once instead of during every call.
    threads = []
    for index in range(len(devices)):
        variable_creator_fn = shared_variable_creator.make_fn(
            shared_variable_store, index)
        t = _MirroredReplicaThread(
            distribution, coord, index, devices, variable_creator_fn, fn,
            distribute_utils.caching_scope_local,
            distribute_utils.select_replica(index, args),
            distribute_utils.select_replica(index,
                                            kwargs), thread_local_callables)
        threads.append(t)

    for t in threads:
        t.start()

    # When `fn` starts `should_run` event is set on _MirroredReplicaThread
    # (`MRT`) threads. The execution waits until
    # `MRT.has_paused` is set, which indicates that either `fn` is
    # complete or a `get_replica_context().merge_call()` is called.  If `fn` is
    # complete, then `MRT.done` is set to True.  Otherwise, arguments
    # of `get_replica_context().merge_call` from all paused threads are grouped
    # and the `merge_fn` is performed.  Results of the
    # `get_replica_context().merge_call` are then set to `MRT.merge_result`.
    # Each such `get_replica_context().merge_call` call returns the
    # `MRT.merge_result` for that thread when `MRT.should_run` event
    # is reset again. Execution of `fn` resumes.

    try:
        with coord.stop_on_exception():
            all_done = False
            while not all_done and not coord.should_stop():
                done = []
                if run_concurrently:
                    for t in threads:
                        t.should_run.set()
                    for t in threads:
                        t.has_paused.wait()
                        t.has_paused.clear()
                        if coord.should_stop():
                            return None
                        done.append(t.done)
                else:
                    for t in threads:
                        t.should_run.set()
                        t.has_paused.wait()
                        t.has_paused.clear()
                        if coord.should_stop():
                            return None
                        done.append(t.done)
                if coord.should_stop():
                    return None
                all_done = all(done)
                if not all_done:
                    if any(done):
                        raise RuntimeError(
                            "Some replicas made a different number of "
                            "replica_context().merge_call() calls.")
                    # get_replica_context().merge_call() case
                    merge_args = distribute_utils.regroup(
                        tuple(t.merge_args for t in threads))
                    merge_kwargs = distribute_utils.regroup(
                        tuple(t.merge_kwargs for t in threads))
                    # We capture the name_scope of the MRT when we call merge_fn
                    # to ensure that if we have opened a name scope in the MRT,
                    # it will be respected when executing the merge function. We only
                    # capture the name_scope from the first MRT and assume it is
                    # the same for all other MRTs.
                    mtt_captured_name_scope = threads[0].captured_name_scope
                    mtt_captured_var_scope = threads[0].captured_var_scope
                    # Capture and merge the control dependencies from all the threads.
                    mtt_captured_control_deps = set()
                    for t in threads:
                        mtt_captured_control_deps.update(
                            t.captured_control_deps)
                    with ops.name_scope(mtt_captured_name_scope),\
                        ops.control_dependencies(mtt_captured_control_deps), \
                        variable_scope.variable_scope(mtt_captured_var_scope):
                        merge_result = threads[0].merge_fn(
                            distribution, *merge_args, **merge_kwargs)
                    for r, t in enumerate(threads):
                        t.merge_result = distribute_utils.select_replica(
                            r, merge_result)
    finally:
        for t in threads:
            t.should_run.set()
        coord.join(threads)

    return distribute_utils.regroup(tuple(t.main_result for t in threads))
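A toy sketch of the should_run/has_paused handshake described in the comments above, with fake per-replica "work" standing in for fn and merge_call; ToyReplicaThread, the step count, and the driver loop are ours, not the real _MirroredReplicaThread machinery.

import threading

# Each worker runs one "step" per handshake: it waits for should_run, does its
# work, then signals has_paused. The driver releases workers one at a time and
# waits for each pause, mirroring the non-concurrent branch of the loop above.
class ToyReplicaThread(threading.Thread):
    def __init__(self, index, num_steps):
        super().__init__()
        self.index = index
        self.num_steps = num_steps
        self.should_run = threading.Event()
        self.has_paused = threading.Event()
        self.done = False
        self.results = []

    def run(self):
        for step in range(self.num_steps):
            self.should_run.wait()
            self.should_run.clear()
            self.results.append((self.index, step))  # stand-in for fn / merge_call
            if step == self.num_steps - 1:
                self.done = True                     # set before signalling the pause
            self.has_paused.set()

threads = [ToyReplicaThread(i, num_steps=3) for i in range(2)]
for t in threads:
    t.start()
all_done = False
while not all_done:
    for t in threads:
        t.should_run.set()
        t.has_paused.wait()
        t.has_paused.clear()
    all_done = all(t.done for t in threads)
for t in threads:
    t.join()
assert [len(t.results) for t in threads] == [3, 3]
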
 def testSelectReplica(self, distribution, synchronization, aggregation):
     with distribution.scope():
         v = variables_lib.Variable(1.,
                                    synchronization=synchronization,
                                    aggregation=aggregation)
     self.assertIs(v, distribute_utils.select_replica(0, v))
Example #18
  def testRaggedSparse(self, distribution, input_type, drop_remainder,
                       defun_type):
    """Test with `RaggedTensor`s and `SparseTensor`s."""
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")

    defun = {"lambda": lambda f: f,
             "tf_function": def_function.function}[defun_type]
    distribution.extended.experimental_enable_get_next_as_optional = True
    global_batch_size = 8

    def dataset_fn(ctx=None):
      ctx = ctx or distribute_lib.InputContext()
      batch_size = ctx.get_per_replica_batch_size(global_batch_size)
      # Use 20 which isn't divisible by 8 to test partial batch behavior.
      row_lengths = np.mod(np.arange(20), 4).astype(np.int64)
      ragged_tensor = ragged_tensor_lib.RaggedTensor.from_row_lengths(
          np.repeat(np.arange(20, dtype=np.float32), row_lengths), row_lengths)
      dataset = dataset_ops.DatasetV2.from_tensor_slices({
          "dense": ragged_tensor.to_tensor(),
          "ragged": ragged_tensor,
          "sparse": ragged_tensor.to_sparse(),
      })
      dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
      return dataset.batch(batch_size, drop_remainder=drop_remainder)

    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)
    dataset = self._wrap_dataset(input_type, dataset_or_input_fn,
                                 distribution.extended._input_workers,
                                 len(distribution.extended.worker_devices),
                                 distribution)
    # Assert that the tensors are rebatched and sparsity is preserved.
    per_replica_batch = defun(lambda x: next(iter(x)))(dataset)
    self.assertAllEqual(
        distribute_utils.select_replica(0, per_replica_batch["dense"]),
        [[0., 0., 0.], [1., 0., 0.], [2., 2., 0.], [3., 3., 3.]])
    self.assertAllEqual(
        distribute_utils.select_replica(1, per_replica_batch["dense"]),
        [[0., 0., 0.], [5., 0., 0.], [6., 6., 0.], [7., 7., 7.]])
    # Transitively check the ragged and sparse tensors by densification.
    for i in range(2):
      self.assertLen(
          distribute_utils.select_replica(i,
                                          per_replica_batch["ragged"]).values,
          6)
      self.assertAllEqual(
          distribute_utils.select_replica(
              i, per_replica_batch["ragged"]).to_tensor(),
          distribute_utils.select_replica(i, per_replica_batch["dense"]))
      self.assertLen(
          distribute_utils.select_replica(i,
                                          per_replica_batch["sparse"]).indices,
          6)
      self.assertAllEqual(
          sparse_ops.sparse_tensor_to_dense(
              distribute_utils.select_replica(i, per_replica_batch["sparse"])),
          distribute_utils.select_replica(i, per_replica_batch["dense"]))
    # Iterate through all the batches and sum them up.
    def sum_batch(per_replica_features):
      """Sums the `PerReplica` values in the `per_replica_features` map."""

      def map_fn(per_replica_values):
        per_replica_sums = distribution.run(
            (lambda x: math_ops.reduce_sum(x.values)) if all(
                map(sparse_tensor.is_sparse, per_replica_values.values)) else
            math_ops.reduce_sum, (per_replica_values,))
        return distribution.reduce(
            reduce_util.ReduceOp.SUM, per_replica_sums, axis=None)

      return nest.map_structure(map_fn, per_replica_features)

    def _reduce(state, batch):
      sums = sum_batch(batch)
      return {name: value + sums[name] for name, value in state.items()}

    def sum_for_loop(dataset):
      sums = {"dense": 0., "ragged": 0., "sparse": 0.}
      for batch in dataset:
        sums = _reduce(sums, batch)
      return sums

    def sum_while_loop(iterator, reduce_fn):
      sums = {"dense": 0., "ragged": 0., "sparse": 0.}
      while True:
        try:
          sums = reduce_fn(sums, iterator)
        except (StopIteration, errors.OutOfRangeError):
          return sums

    while_sums = sum_while_loop(
        iter(dataset),
        defun(lambda state, iterator: _reduce(state, next(iterator))))
    self.assertAllEqual(
        nest.flatten(while_sums),
        # When there's no partial batch, the sum is smaller.
        [200. if drop_remainder else 310.] * 3)
    for_sums = defun(sum_for_loop)(dataset)
    # For loops always call get_next_as_optional() inside tf.functions, so we
    # expect 310 here when using an input function (as there are 5 batches of
    # size 4 round-robined over 2 replicas).
    expected_for_sum = 200.
    if (not drop_remainder or (
        defun_type == "tf_function" and input_type == "input_fn")):
      expected_for_sum = 310.
    self.assertAllEqual(nest.flatten(for_sums), [expected_for_sum] * 3)
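A small sketch of the ragged/sparse bookkeeping the dataset above relies on, using public TF APIs only; the shapes and row lengths here are chosen for illustration rather than taken from the test.

import numpy as np
import tensorflow as tf

row_lengths = np.mod(np.arange(8), 4).astype(np.int64)   # [0, 1, 2, 3, 0, 1, 2, 3]
ragged = tf.RaggedTensor.from_row_lengths(
    np.repeat(np.arange(8, dtype=np.float32), row_lengths), row_lengths)

# Densifying the ragged tensor and densifying its sparse form agree, which is
# what the per-replica assertions in the test check transitively.
dense_from_ragged = ragged.to_tensor()
dense_from_sparse = tf.sparse.to_dense(ragged.to_sparse())
assert np.array_equal(dense_from_ragged.numpy(), dense_from_sparse.numpy())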