  def _test_iterator(self, input_fn, worker_device_pairs, expected_values,
                     sess=None):
    devices = nest.flatten([ds for _, ds in worker_device_pairs])
    iterator = values.InputFunctionIterator(input_fn, worker_device_pairs)

    evaluate = lambda x: sess.run(x) if sess else self.evaluate(x)

    evaluate(iterator.initialize())

    for expected_value in expected_values:
      next_element = iterator.get_next()
      computed_value = evaluate(
          [values.select_device(d, next_element) for d in devices])
      self.assertEqual(expected_value, computed_value)

    with self.assertRaises(errors.OutOfRangeError):
      next_element = iterator.get_next()
      evaluate([values.select_device(d, next_element) for d in devices])

    # After re-initializing the iterator, should be able to iterate again.
    evaluate(iterator.initialize())

    for expected_value in expected_values:
      next_element = iterator.get_next()
      computed_value = evaluate(
          [values.select_device(d, next_element) for d in devices])
      self.assertEqual(expected_value, computed_value)
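
# The loops above hinge on values.select_device() returning, for each device,
# that device's component of the structure produced by get_next(). A minimal
# pure-Python stand-in (an illustration only, not the TensorFlow
# implementation) of the behavior the assertions rely on:
def _select_device_sketch(device, per_device_values):
  """Pick the value that a {device: value} mapping holds for `device`."""
  return per_device_values[device]

_per_device = {"/device:CPU:0": 0, "/device:GPU:0": 1}
assert [_select_device_sketch(d, _per_device)
        for d in ("/device:CPU:0", "/device:GPU:0")] == [0, 1]
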
    def _test_iterator_with_prefetch(self, devices, dataset, expected_values):
        if not context.executing_eagerly():
            per_device_dataset = values.PerDeviceDataset(
                dataset, devices, prefetch_on_device=True)
            iterator = per_device_dataset.make_one_shot_iterator()

            # With prefetching, we cannot guarantee which input ends up on which
            # device, so we verify that the complete set seen on all devices is
            # correct, and equal numbers are distributed to each device.
            combined_actual = []
            combined_expected = []
            for expected_value in expected_values:
                next_element = iterator.get_next()
                combined_actual.extend(
                    self.evaluate([
                        values.select_device(d, next_element) for d in devices
                    ]))
                combined_expected.extend(expected_value)

            self.assertEqual(set(combined_expected), set(combined_actual))

            with self.assertRaises(errors.OutOfRangeError):
                next_element = iterator.get_next()
                self.evaluate(
                    [values.select_device(d, next_element) for d in devices])
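
# Pure-Python illustration (hedged, not taken from the original test) of the
# comment above: with prefetching the per-step device assignment is not
# deterministic, so the test only checks that the union of values seen across
# all devices matches the union of the expected values.
_expected_steps = [[0, 1], [2, 3]]
_actual_steps = [[1, 0], [2, 3]]  # same values, swapped within one step
assert (set(v for step in _expected_steps for v in step) ==
        set(v for step in _actual_steps for v in step))
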
    def testNested(self):
        result = values.regroup({
            _device_str(0): _nested_value("1"),
            _device_str(1): _nested_value("2")
        })
        self.assertIsInstance(result, tuple)
        self.assertEqual(3, len(result))
        self._is_per_device(result[0], ["a1", "a2"])
        self._is_per_device(result[2], ["h1", "h2"])

        self.assertIsInstance(result[1], list)
        self.assertEqual(3, len(result[1]))
        self._is_per_device(result[1][0], ["b1", "b2"])
        self._is_per_device(result[1][2], ["g1", "g2"])

        self.assertIsInstance(result[1][1], dict)
        self.assertEqual(set(["c", "e"]), set(result[1][1].keys()))
        self._is_per_device(result[1][1]["c"], ["d1", "d2"])
        self._is_per_device(result[1][1]["e"], ["f1", "f2"])

        # Also test that we can undo the merge using select_device()
        self.assertEqual(_nested_value("1"),
                         values.select_device(_device_str(0), result))
        self.assertEqual(_nested_value("2"),
                         values.select_device(_device_str(1), result))
        # select_device_mirrored() should fail due to non-mirrored values
        with self.assertRaises(TypeError):
            values.select_device_mirrored(_device_str(0), result)
        with self.assertRaises(TypeError):
            values.select_device_mirrored(_device_str(1), result)
    def testWrapClass(self):
        # Normally a mirrored value would be the same across devices, but
        # for a test it is convenient to be able to tell the values apart.
        result = values.regroup(
            {
                _device_str(0): _nested_value("1"),
                _device_str(1): _nested_value("2")
            }, values.Mirrored)
        self.assertIsInstance(result, tuple)
        self.assertEqual(3, len(result))
        self._is_per_device(result[0], ["a1", "a2"], values.Mirrored)
        self._is_per_device(result[2], ["h1", "h2"], values.Mirrored)

        self.assertIsInstance(result[1], list)
        self.assertEqual(3, len(result[1]))
        self._is_per_device(result[1][0], ["b1", "b2"], values.Mirrored)
        self._is_per_device(result[1][2], ["g1", "g2"], values.Mirrored)

        self.assertIsInstance(result[1][1], dict)
        self.assertEqual(set(["c", "e"]), set(result[1][1].keys()))
        self._is_per_device(result[1][1]["c"], ["d1", "d2"], values.Mirrored)
        self._is_per_device(result[1][1]["e"], ["f1", "f2"], values.Mirrored)

        # Also test that we can undo the merge using select_device()
        self.assertEqual(_nested_value("1"),
                         values.select_device(_device_str(0), result))
        self.assertEqual(_nested_value("2"),
                         values.select_device(_device_str(1), result))
        # Values are marked as mirrored, so select_device_mirrored() is allowed.
        self.assertEqual(_nested_value("1"),
                         values.select_device_mirrored(_device_str(0), result))
        self.assertEqual(_nested_value("2"),
                         values.select_device_mirrored(_device_str(1), result))
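
# Hedged reconstruction, inferred from the assertions in testNested and
# testWrapClass above, of the module-level fixtures those tests rely on; the
# real helpers in the test file may differ in detail.
def _device_str(device_id):
  # Assumed device-naming helper; the actual device type used is a guess.
  return "/device:GPU:%d" % device_id

def _nested_value(d):
  # Chosen so that, e.g., regrouping two of these makes result[1][1]["c"]
  # a per-device value holding ["d1", "d2"].
  return ("a" + d, ["b" + d, {"c": "d" + d, "e": "f" + d}, "g" + d], "h" + d)

assert _nested_value("1") == ("a1", ["b1", {"c": "d1", "e": "f1"}, "g1"], "h1")
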
Example #8
  def _test_iterator_no_prefetch(self, devices, dataset, expected_values):
    per_device_dataset = values.PerDeviceDataset(
        dataset, devices, prefetch_on_device=False)
    iterator = per_device_dataset.make_one_shot_iterator()

    for expected_value in expected_values:
      next_element = iterator.get_next()
      actual = self.evaluate([
          values.select_device(d, next_element) for d in devices])
      self.assertEqual(expected_value, actual)

    with self.assertRaises(errors.OutOfRangeError):
      next_element = iterator.get_next()
      self.evaluate([
          values.select_device(d, next_element) for d in devices])
Example #10
  def _test_iterator(self, iterator, devices, expected_values):
    next_element = iterator.get_next()
    for device in devices:
      v = values.select_device(device, next_element)
      # The `v` here can be a tuple.
      for element in nest.flatten(v):
        self.assertTrue(element.device in device)

    for expected_value in expected_values:
      actual = self.evaluate(
          [values.select_device(d, next_element) for d in devices])
      self.assertEqual(expected_value, actual)

    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate([values.select_device(d, next_element) for d in devices])
Example #11
    def _test_iterator(self, sess, iterator, devices, expected_values):
        next_element = iterator.get_next()
        for device in devices:
            v = values.select_device(device, next_element)
            # The `v` here can be a tuple.
            for element in nest.flatten(v):
                self.assertTrue(element.device in device)

        for expected_value in expected_values:
            actual = sess.run(
                [values.select_device(d, next_element) for d in devices])
            self.assertEqual(expected_value, actual)

        with self.assertRaises(errors.OutOfRangeError):
            sess.run([values.select_device(d, next_element) for d in devices])
    def testNamedTupleEstimatorSpec(self):
        with context.graph_mode(), ops.Graph().as_default():
            created_estimator_specs = {}
            to_regroup = {}

            for device_id in range(3):
                spec = model_fn_lib.EstimatorSpec(
                    mode=model_fn_lib.ModeKeys.TRAIN,
                    loss=constant_op.constant(device_id / 2),
                    train_op=array_ops.identity(
                        constant_op.constant(device_id)))
                created_estimator_specs[device_id] = spec
                to_regroup[_device_str(device_id)] = spec

            merged_estimator_spec = values.regroup(to_regroup)

            self.assertTrue(
                isinstance(merged_estimator_spec, model_fn_lib.EstimatorSpec))
            self.assertEqual(model_fn_lib.ModeKeys.TRAIN,
                             merged_estimator_spec.mode)
            for device_id in range(3):
                d = _device_str(device_id)
                self.assertEqual(created_estimator_specs[device_id].loss,
                                 merged_estimator_spec.loss.get(d))
                self.assertEqual(created_estimator_specs[device_id].train_op,
                                 merged_estimator_spec.train_op.get(d))
                # Scaffold is populated by `EstimatorSpec.__new__`.
                self.assertEqual(created_estimator_specs[device_id].scaffold,
                                 merged_estimator_spec.scaffold.get(d))
                # Also test that we can undo the merge using select_device()
                self.assertEqual(
                    created_estimator_specs[device_id],
                    values.select_device(_device_str(device_id),
                                         merged_estimator_spec))
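
# The test above exercises regroup() on EstimatorSpec, a namedtuple subclass:
# the merged value keeps the namedtuple type while each field becomes a
# per-device container. A hedged, pure-Python stand-in using an ordinary
# namedtuple, with plain dicts standing in for the PerDevice wrappers (this
# is not the TensorFlow implementation):
import collections

_Spec = collections.namedtuple("Spec", ["loss", "train_op"])

def _regroup_spec_sketch(per_device_specs):
  """Merge {device: Spec} into one Spec whose fields map device -> value."""
  devices = sorted(per_device_specs)
  return _Spec(
      loss={d: per_device_specs[d].loss for d in devices},
      train_op={d: per_device_specs[d].train_op for d in devices})

_merged = _regroup_spec_sketch({"/device:GPU:0": _Spec(loss=0.0, train_op="t0"),
                                "/device:GPU:1": _Spec(loss=0.5, train_op="t1")})
assert isinstance(_merged, _Spec) and _merged.loss["/device:GPU:1"] == 0.5
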
Example #14
    def _test_iterator_with_prefetch(self, devices, dataset, expected_values):
        if not context.executing_eagerly():
            per_device_dataset = values.PerDeviceDataset(
                dataset, devices, prefetch_on_device=True)
            iterator = per_device_dataset.make_initializable_iterator()
            self.evaluate([iterator.initializer])

            for expected_value in expected_values:
                next_element = iterator.get_next()
                computed_value = self.evaluate(
                    [values.select_device(d, next_element) for d in devices])
                self.assertEqual(expected_value, computed_value)

            with self.assertRaises(errors.OutOfRangeError):
                next_element = iterator.get_next()
                self.evaluate(
                    [values.select_device(d, next_element) for d in devices])
  def _test_iterator(self, devices, dataset, expected_values):
    per_replica_dataset = values.PerReplicaDataset(dataset, devices)
    if context.executing_eagerly():
      iterator = per_replica_dataset.make_one_shot_iterator()
    else:
      iterator = per_replica_dataset.make_initializable_iterator()
      self.evaluate([iterator.initializer])

    for expected_value in expected_values:
      next_element = iterator.get_next()
      computed_value = self.evaluate(
          [values.select_device(d, next_element) for d in devices])
      self.assertEqual(expected_value, computed_value)

    with self.assertRaises(errors.OutOfRangeError):
      next_element = iterator.get_next()
      self.evaluate([
          values.select_device(d, next_element) for d in devices])
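
# Hedged note on how expected_values line up with devices in these dataset
# tests: with N devices, each get_next() step is expected to yield one
# consecutive element per device. A small helper producing that layout (an
# assumption about the distribution order, not taken from the original file):
def _expected_per_step(num_elements, num_devices):
  return [list(range(i, i + num_devices))
          for i in range(0, num_elements, num_devices)]

assert _expected_per_step(8, 2) == [[0, 1], [2, 3], [4, 5], [6, 7]]
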
Example #18
  def testSameId(self):
    foo = object()
    result = values.regroup({_device_str(0): ("a", foo),
                             _device_str(1): ("b", foo)})
    self.assertIsInstance(result, tuple)
    self.assertEqual(2, len(result))
    self._is_per_device(result[0], ["a", "b"])
    self.assertIs(foo, result[1])

    # Test select_device(), should undo the merge done by regroup().
    result_0 = values.select_device(_device_str(0), result)
    self.assertIsInstance(result_0, tuple)
    self.assertEqual(2, len(result_0))
    self.assertEqual("a", result_0[0])
    self.assertIs(foo, result_0[1])
    result_1 = values.select_device(_device_str(1), result)
    self.assertIsInstance(result_1, tuple)
    self.assertEqual(2, len(result_1))
    self.assertEqual("b", result_1[0])
    self.assertIs(foo, result_1[1])
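
# Hedged, pure-Python illustration of the property testSameId checks: when
# every device holds the very same Python object, regroup() passes it through
# unwrapped instead of building a per-device container. The stand-in below
# uses a plain dict where the real code builds a PerDevice value.
def _regroup_leaf_sketch(per_device_leaf_values):
  leaves = list(per_device_leaf_values.values())
  if all(leaf is leaves[0] for leaf in leaves):
    return leaves[0]  # same object on every device: pass it through
  return dict(per_device_leaf_values)  # otherwise keep a per-device mapping

_foo = object()
assert _regroup_leaf_sketch({"/gpu:0": _foo, "/gpu:1": _foo}) is _foo
assert _regroup_leaf_sketch({"/gpu:0": "a", "/gpu:1": "b"}) == {
    "/gpu:0": "a", "/gpu:1": "b"}
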
Example #20
  def _call_and_check(self, model_fn, inputs, expected_result, defuns,
                      two_variables=False):
    cpu_dev = device_util.canonicalize("CPU:0")
    gpu_dev = device_util.canonicalize("GPU:0")
    devices = [cpu_dev, gpu_dev]
    dist = mirrored_strategy.MirroredStrategy(devices)

    with dist.scope():
      mock_model = MockModel(two_variables)
      self.evaluate(variables.global_variables_initializer())

      result = dist.call_for_each_tower(model_fn, mock_model, *inputs,
                                        run_concurrently=False)
      for device in devices:
        device_result = values.select_device(device, result)
        device_expected_result = values.select_device(device, expected_result)
        self.assertAllClose(device_expected_result,
                            self.evaluate(device_result))

      for defun in defuns:
        self.assertEqual(set(mock_model.variables), set(defun.variables))
Example #23
  def testOneDevice(self):
    result = values.regroup({_device_str(0): _nested_value("1")})
    # On one device regroup() and select_device() are basically identity.
    self.assertEqual(_nested_value("1"), result)
    self.assertEqual(_nested_value("1"),
                     values.select_device(_device_str(0), result))

    # The one exception has to do with MirroredVariables.
    d = "/device:CPU:0"
    with ops.device(d):
      v = variable_scope.get_variable(
          name="v", initializer=1., use_resource=True)
      index = {d: v}
    mirrored = values.MirroredVariable(index, v)
    result = values.regroup(index)
    self.assertIs(mirrored, result)
Example #25
    def replicated_model_fn(features, labels, mode, params=None, config=None):
        """Replicated version of `model_fn` to be used instead."""

        # If the input is not already an instance of PerDevice (i.e. was not
        # produced by a PerDeviceDataset), split the batch evenly between
        # devices.
        if not isinstance(labels, values.PerDevice):
            if isinstance(features, dict):
                for v in features.values():
                    assert not isinstance(v, values.PerDevice)
            else:
                assert not isinstance(features, values.PerDevice)
            feature_shards, label_shards = _split_batch(
                features, labels, len(devices), device=consolidation_device)
        else:
            feature_shards, label_shards = zip(
                *(values.select_device(device, (features, labels))
                  for device in devices))

        tower_specs = _get_loss_towers(model_fn=model_fn,
                                       mode=mode,
                                       features=feature_shards,
                                       labels=label_shards,
                                       params=params,
                                       loss_reduction=loss_reduction,
                                       config=config,
                                       devices=devices,
                                       local_ps_devices=ps_devices)

        if mode == model_fn_lib.ModeKeys.TRAIN:
            train_op = _minimize_towers(tower_specs)
            return _train_spec(tower_specs,
                               train_op,
                               aggregation_device=consolidation_device)
        elif mode == model_fn_lib.ModeKeys.EVAL:
            return _eval_spec(tower_specs,
                              aggregation_device=consolidation_device)
        elif mode == model_fn_lib.ModeKeys.PREDICT:
            return _predict_spec(tower_specs,
                                 aggregation_device=consolidation_device)
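
# Hedged, plain-Python illustration (not the TensorFlow implementation) of
# the zip(*(values.select_device(device, (features, labels)) ...)) idiom used
# above: selecting per device and transposing turns per-device
# (features, labels) pairs into one features shard and one labels shard per
# device.
def _select_sketch(device, structure):
  """Stand-in for values.select_device with {device: value} leaves."""
  if isinstance(structure, dict):
    return structure[device]
  return tuple(_select_sketch(device, s) for s in structure)

_devices = ["/device:GPU:0", "/device:GPU:1"]
_features = {"/device:GPU:0": "f0", "/device:GPU:1": "f1"}
_labels = {"/device:GPU:0": "l0", "/device:GPU:1": "l1"}
_feature_shards, _label_shards = zip(
    *(_select_sketch(d, (_features, _labels)) for d in _devices))
assert _feature_shards == ("f0", "f1") and _label_shards == ("l0", "l1")
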
def _call_for_each_tower(distribution, fn, *args, **kwargs):
  """Run `fn` in separate threads, once per tower/worker device.

  Args:
    distribution: the DistributionStrategy object.
    fn: function to run (will be run once per device, each in its own thread).
    *args: positional arguments for `fn`
    **kwargs: keyword arguments for `fn`.
        `"run_concurrently"`: Boolean indicating whether executions of `fn`
           can be run concurrently (under eager execution only), defaults to
           `True`.

  Returns:
    Merged return value of `fn` across all towers.

  Raises:
    RuntimeError: If fn() calls get_tower_context().merge_call() a different
        number of times from the available devices.
  """
  run_concurrently = kwargs.pop("run_concurrently", True)
  if not context.executing_eagerly():
    # Lots of TF library code isn't thread-safe in graph mode, and
    # there is little to be gained by turning on multithreading when
    # constructing a graph.
    run_concurrently = False
    # Needed for per-thread device, etc. contexts in graph mode.
    ops.get_default_graph().switch_to_thread_local()
  elif run_concurrently is None:
    run_concurrently = True

  coord = coordinator.Coordinator(clean_stop_exception_types=(_RequestedStop,))

  shared_variable_store = {}

  # TODO(isaprykin): Create these threads once instead of during every run()
  # call.
  threads = []
  for index, d in enumerate(distribution.worker_devices):
    variable_creator_fn = shared_variable_creator.make_fn(
        shared_variable_store, index)
    t = MirroredStrategy._MirroredTowerThread(  # pylint: disable=protected-access
        distribution, coord, d, variable_creator_fn, fn,
        *values.select_device(d, args), **values.select_device(d, kwargs))
    threads.append(t)

  for t in threads:
    t.start()

  # When `fn` starts `should_run` event is set on _MirroredTowerThread
  # (`MTT`) threads. The execution waits until
  # `MTT.has_paused` is set, which indicates that either `fn` is
  # complete or a `get_tower_context().merge_call()` is called.  If `fn` is
  # complete, then `MTT.done` is set to True.  Otherwise, arguments
  # of `get_tower_context().merge_call` from all paused threads are grouped
  # and the `merge_fn` is performed.  Results of the
  # `get_tower_context().merge_call` are then set to `MTT.merge_result`.
  # Each such `get_tower_context().merge_call` call returns the
  # `MTT.merge_result` for that thread when `MTT.should_run` event
  # is reset again. Execution of `fn` resumes.

  try:
    with coord.stop_on_exception():
      all_done = False
      while not all_done and not coord.should_stop():
        done = []
        if run_concurrently:
          for t in threads:
            t.should_run.set()
          for t in threads:
            t.has_paused.wait()
            t.has_paused.clear()
            if coord.should_stop():
              return None
            done.append(t.done)
        else:
          for t in threads:
            t.should_run.set()
            t.has_paused.wait()
            t.has_paused.clear()
            if coord.should_stop():
              return None
            done.append(t.done)
        if coord.should_stop():
          return None
        all_done = all(done)
        if not all_done:
          if any(done):
            raise RuntimeError("Some towers made a different number of "
                               "tower_context().merge_call() calls.")
          # get_tower_context().merge_call() case
          merge_args = values.regroup({t.device: t.merge_args for t in threads})
          merge_kwargs = values.regroup(
              {t.device: t.merge_kwargs for t in threads})
          # We capture the name_scope of the MTT when we call merge_fn
          # to ensure that if we have opened a name scope in the MTT,
          # it will be respected when executing the merge function. We only
          # capture the name_scope from the first MTT and assume it is
          # the same for all other MTTs.
          mtt_captured_name_scope = threads[0].captured_name_scope
          with ops.name_scope(mtt_captured_name_scope):
            merge_result = threads[0].merge_fn(distribution, *merge_args,
                                               **merge_kwargs)
          for t in threads:
            t.merge_result = values.select_device(t.device, merge_result)
  finally:
    for t in threads:
      t.should_run.set()
    coord.join(threads)

  return values.regroup({t.device: t.main_result for t in threads})
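
# A deliberately simplified, standard-library sketch (hedged; not the
# MirroredStrategy implementation above) of the should_run / has_paused
# hand-off that the comments describe: the controller releases each worker
# thread for one step, then waits for it to pause before moving on.
import threading

class _StepThread(threading.Thread):

  def __init__(self, num_steps):
    super(_StepThread, self).__init__()
    self.should_run = threading.Event()
    self.has_paused = threading.Event()
    self.done = False
    self.results = []
    self._num_steps = num_steps

  def run(self):
    for step in range(self._num_steps):
      self.should_run.wait()
      self.should_run.clear()
      self.results.append(step)  # stand-in for running a slice of `fn`
      self.has_paused.set()      # pause at the would-be merge point
    self.done = True

_threads = [_StepThread(num_steps=2) for _ in range(2)]
for _t in _threads:
  _t.start()
for _ in range(2):               # drive both threads in lock step
  for _t in _threads:
    _t.should_run.set()
  for _t in _threads:
    _t.has_paused.wait()
    _t.has_paused.clear()
for _t in _threads:
  _t.join()
assert all(_t.done and _t.results == [0, 1] for _t in _threads)
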
def _call_for_each_replica(distribution, fn, args, kwargs):
    """Run `fn` in separate threads, once per replica/worker device.

  Args:
    distribution: the DistributionStrategy object.
    fn: function to run (will be run once per device, each in its own thread).
    args: positional arguments for `fn`
    kwargs: keyword arguments for `fn`.

  Returns:
    Merged return value of `fn` across all replicas.

  Raises:
    RuntimeError: If fn() calls get_replica_context().merge_call() a different
        number of times from the available devices.
  """
    # TODO(josh11b): Add this option once we add synchronization to variable
    # creation. Until then, this is pretty unsafe to use.
    run_concurrently = False
    if not context.executing_eagerly():
        # Needed for per-thread device, etc. contexts in graph mode.
        ops.get_default_graph().switch_to_thread_local()

    coord = coordinator.Coordinator(
        clean_stop_exception_types=(_RequestedStop, ))

    shared_variable_store = {}

    # TODO(isaprykin): Create these threads once instead of during every run()
    # call.
    threads = []
    for index, d in enumerate(distribution.worker_devices):
        variable_creator_fn = shared_variable_creator.make_fn(
            shared_variable_store, index)
        t = MirroredStrategy._MirroredReplicaThread(  # pylint: disable=protected-access
            distribution, coord, d, variable_creator_fn, fn,
            *values.select_device(d, args), **values.select_device(d, kwargs))
        threads.append(t)

    for t in threads:
        t.start()

    # When `fn` starts `should_run` event is set on _MirroredReplicaThread
    # (`MRT`) threads. The execution waits until
    # `MRT.has_paused` is set, which indicates that either `fn` is
    # complete or a `get_replica_context().merge_call()` is called.  If `fn` is
    # complete, then `MRT.done` is set to True.  Otherwise, arguments
    # of `get_replica_context().merge_call` from all paused threads are grouped
    # and the `merge_fn` is performed.  Results of the
    # `get_replica_context().merge_call` are then set to `MRT.merge_result`.
    # Each such `get_replica_context().merge_call` call returns the
    # `MRT.merge_result` for that thread when `MRT.should_run` event
    # is reset again. Execution of `fn` resumes.

    try:
        with coord.stop_on_exception():
            all_done = False
            while not all_done and not coord.should_stop():
                done = []
                if run_concurrently:
                    for t in threads:
                        t.should_run.set()
                    for t in threads:
                        t.has_paused.wait()
                        t.has_paused.clear()
                        if coord.should_stop():
                            return None
                        done.append(t.done)
                else:
                    for t in threads:
                        t.should_run.set()
                        t.has_paused.wait()
                        t.has_paused.clear()
                        if coord.should_stop():
                            return None
                        done.append(t.done)
                if coord.should_stop():
                    return None
                all_done = all(done)
                if not all_done:
                    if any(done):
                        raise RuntimeError(
                            "Some replicas made a different number of "
                            "replica_context().merge_call() calls.")
                    # get_replica_context().merge_call() case
                    merge_args = values.regroup(
                        {t.device: t.merge_args
                         for t in threads})
                    merge_kwargs = values.regroup(
                        {t.device: t.merge_kwargs
                         for t in threads})
                    # We capture the name_scope of the MRT when we call merge_fn
                    # to ensure that if we have opened a name scope in the MRT,
                    # it will be respected when executing the merge function. We only
                    # capture the name_scope from the first MRT and assume it is
                    # the same for all other MRTs.
                    mtt_captured_name_scope = threads[0].captured_name_scope
                    with ops.name_scope(mtt_captured_name_scope):
                        merge_result = threads[0].merge_fn(
                            distribution, *merge_args, **merge_kwargs)
                    for t in threads:
                        t.merge_result = values.select_device(
                            t.device, merge_result)
    finally:
        for t in threads:
            t.should_run.set()
        coord.join(threads)

    return values.regroup({t.device: t.main_result for t in threads})
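
# Hedged, pure-Python sketch of the merge step above (plain dicts stand in
# for the PerReplica/PerDevice wrappers; this is not the TensorFlow code):
# per-thread arguments are regrouped into one grouped structure, merge_fn
# runs once on it, and selecting per device undoes the grouping.
def _regroup_sketch(per_device):
  """Merge {device: tuple_of_args} into a tuple of {device: arg} dicts."""
  devices = list(per_device)
  num_args = len(next(iter(per_device.values())))
  return tuple({d: per_device[d][i] for d in devices} for i in range(num_args))

def _select_result_sketch(device, structure):
  if isinstance(structure, dict):
    return structure[device]
  if isinstance(structure, tuple):
    return tuple(_select_result_sketch(device, s) for s in structure)
  return structure  # a mirrored (identical-everywhere) leaf

_per_thread_args = {"/device:GPU:0": (1, 10), "/device:GPU:1": (2, 20)}
_merge_args = _regroup_sketch(_per_thread_args)
# Selecting undoes the regroup, handing each thread back its own arguments:
assert _select_result_sketch("/device:GPU:1", _merge_args) == (2, 20)
# merge_fn runs once on the grouped values (here a toy cross-device sum) and
# the mirrored result is handed back to every thread unchanged:
_merge_result = tuple(sum(arg.values()) for arg in _merge_args)
assert _select_result_sketch("/device:GPU:0", _merge_result) == (3, 30)
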
Example #28
    def _call_for_each_tower(self, fn, *args, **kwargs):
        """Run `fn` in separate threads, once per tower/worker device.

    Args:
      fn: function to run (will be run once per device, each in its own thread).
      *args: positional arguments for `fn`
      **kwargs: keyword arguments for `fn`.
          `"run_concurrently"`: Boolean indicating whether executions of `fn`
             can be run concurrently (under eager execution only), defaults to
             `True`.

    Returns:
      Merged return value of `fn` across all towers.

    Raises:
      RuntimeError: If fn() calls get_tower_context().merge_call() a different
          number of times when called for different devices.
    """
        run_concurrently = kwargs.pop("run_concurrently", True)
        if not context.executing_eagerly():
            # Lots of TF library code isn't thread-safe in graph mode, and
            # there is little to be gained by turning on multithreading when
            # constructing a graph.
            run_concurrently = False
            # Needed for per-thread device, etc. contexts in graph mode.
            ops.get_default_graph().switch_to_thread_local()
        elif run_concurrently is None:
            run_concurrently = True

        coord = coordinator.Coordinator(
            clean_stop_exception_types=(_RequestedStop, ))

        shared_variable_store = {}

        # TODO(isaprykin): Create these threads once instead of during every run()
        # call.
        threads = []
        for index, d in enumerate(self._devices):
            variable_creator_fn = shared_variable_creator.make_fn(
                shared_variable_store, index)
            t = MirroredStrategy._MirroredTowerThread(
                self, coord, d, variable_creator_fn, fn,
                *values.select_device(d, args),
                **values.select_device(d, kwargs))
            threads.append(t)

        for t in threads:
            t.start()

        # When `fn` starts `should_run` event is set on _MirroredTowerThread
        # (`MTT`) threads. The execution waits until
        # `MTT.has_paused` is set, which indicates that either `fn` is
        # complete or a `get_tower_context().merge_call()` is called.  If `fn` is
        # complete, then `MTT.done` is set to True.  Otherwise, arguments
        # of `get_tower_context().merge_call` from all paused threads are grouped
        # and the `merge_fn` is performed.  Results of the
        # `get_tower_context().merge_call` are then set to `MTT.merge_result`.
        # Each such `get_tower_context().merge_call` call returns the
        # `MTT.merge_result` for that thread when `MTT.should_run` event
        # is reset again. Execution of `fn` resumes.

        try:
            with coord.stop_on_exception():
                all_done = False
                while not all_done and not coord.should_stop():
                    done = []
                    if run_concurrently:
                        for t in threads:
                            t.should_run.set()
                        for t in threads:
                            t.has_paused.wait()
                            t.has_paused.clear()
                            if coord.should_stop():
                                return None
                            done.append(t.done)
                    else:
                        for t in threads:
                            t.should_run.set()
                            t.has_paused.wait()
                            t.has_paused.clear()
                            if coord.should_stop():
                                return None
                            done.append(t.done)
                    if coord.should_stop():
                        return None
                    all_done = all(done)
                    if not all_done:
                        if any(done):
                            raise RuntimeError(
                                "Some towers made a different number of "
                                "tower_context().merge_call() calls.")
                        # get_tower_context().merge_call() case
                        merge_args = values.regroup(
                            {t.device: t.merge_args
                             for t in threads})
                        merge_kwargs = values.regroup(
                            {t.device: t.merge_kwargs
                             for t in threads})
                        merge_result = threads[0].merge_fn(
                            self, *merge_args, **merge_kwargs)
                        for t in threads:
                            t.merge_result = values.select_device(
                                t.device, merge_result)
        finally:
            for t in threads:
                t.should_run.set()
            coord.join(threads)

        return values.regroup({t.device: t.main_result for t in threads})