def _test_input_fn_iterator(self, task_type, task_id, num_gpus, input_fn,
                              expected_values):
    """Runs an input-fn iterator to exhaustion twice and checks its values.

    A distribution, master target and session config are obtained from
    `self._get_test_object(task_type, task_id, num_gpus)`. Inside a fresh
    graph and a cached session, an iterator is built from `input_fn`,
    initialized, and each fetched element (one component per worker device)
    is compared against the matching entry of `expected_values`. The next
    fetch must raise `errors.OutOfRangeError`. The iterator is then
    re-initialized and the whole sequence is checked a second time.

    Args:
      task_type: forwarded to `self._get_test_object`.
      task_id: forwarded to `self._get_test_object`.
      num_gpus: forwarded to `self._get_test_object`.
      input_fn: passed to `distribution.make_input_fn_iterator`.
      expected_values: iterable of expected per-step results; each entry is
        compared with `assertEqual` against the list of per-replica values.
    """
    distribution, master_target, config = self._get_test_object(
        task_type, task_id, num_gpus)
    worker_devices = distribution.extended.worker_devices

    with ops.Graph().as_default(), \
         self.cached_session(config=config,
                             target=master_target) as sess:
      iterator = distribution.make_input_fn_iterator(input_fn)

      def run_next():
        """Fetches the next element and evaluates it for every replica."""
        element = iterator.get_next()
        return sess.run([
            values.select_replica(replica, element)
            for replica in range(len(worker_devices))
        ])

      def check_full_pass():
        """Compares one complete pass over the data with `expected_values`."""
        for expected in expected_values:
          self.assertEqual(expected, run_next())

      sess.run(iterator.initialize())
      check_full_pass()

      # One more fetch past the end must signal exhaustion.
      with self.assertRaises(errors.OutOfRangeError):
        run_next()

      # After re-initializing the iterator, should be able to iterate again.
      sess.run(iterator.initialize())
      check_full_pass()
示例#2
0
  def testNested(self):
    """Checks regroup() on a nested structure across two replicas.

    The regrouped result must be a 3-tuple whose middle entry is a 3-element
    list holding a dict with keys "c" and "e"; every leaf is checked with
    `self._is_per_replica`. `select_replica()` must give back the original
    per-replica inputs, and `select_device_mirrored()` must raise `TypeError`
    because the leaves are not mirrored.
    """
    device_map = values.ReplicaDeviceMap((_device_str(0), _device_str(1)))
    per_replica_inputs = (_nested_value("1"), _nested_value("2"))
    result = values.regroup(device_map, per_replica_inputs)

    # Outer level: a tuple of three entries.
    self.assertIsInstance(result, tuple)
    self.assertEqual(3, len(result))
    self._is_per_replica(result[0], ["a1", "a2"])
    self._is_per_replica(result[2], ["h1", "h2"])

    # Middle entry: a list of three entries.
    middle = result[1]
    self.assertIsInstance(middle, list)
    self.assertEqual(3, len(middle))
    self._is_per_replica(middle[0], ["b1", "b2"])
    self._is_per_replica(middle[2], ["g1", "g2"])

    # Innermost entry: a dict with exactly the keys "c" and "e".
    inner = middle[1]
    self.assertIsInstance(inner, dict)
    self.assertEqual({"c", "e"}, set(inner.keys()))
    self._is_per_replica(inner["c"], ["d1", "d2"])
    self._is_per_replica(inner["e"], ["f1", "f2"])

    # Also test that we can undo the merge using select_replica()
    for replica_id, suffix in enumerate(("1", "2")):
      self.assertEqual(_nested_value(suffix),
                       values.select_replica(replica_id, result))

    # select_device_mirrored() should fail due to non-mirrored values
    for device_index in (0, 1):
      with self.assertRaises(TypeError):
        values.select_device_mirrored(_device_str(device_index), result)
示例#3
0
  def testWrapClass(self):
    """Checks regroup() with `values.Mirrored` as the wrapping class.

    Same structural checks as for the default wrapper, except every leaf is
    verified with `self._is_per_replica(..., values.Mirrored)`. Both
    `select_replica()` and `select_device_mirrored()` must give back the
    original per-replica inputs.
    """
    # Normally a mirrored value would be the same across devices, but
    # for a test it is convenient to be able to tell the values apart.
    device_map = values.ReplicaDeviceMap((_device_str(0), _device_str(1)))
    per_replica_inputs = (_nested_value("1"), _nested_value("2"))
    result = values.regroup(device_map, per_replica_inputs, values.Mirrored)

    # Outer level: a tuple of three entries.
    self.assertIsInstance(result, tuple)
    self.assertEqual(3, len(result))
    self._is_per_replica(result[0], ["a1", "a2"], values.Mirrored)
    self._is_per_replica(result[2], ["h1", "h2"], values.Mirrored)

    # Middle entry: a list of three entries.
    middle = result[1]
    self.assertIsInstance(middle, list)
    self.assertEqual(3, len(middle))
    self._is_per_replica(middle[0], ["b1", "b2"], values.Mirrored)
    self._is_per_replica(middle[2], ["g1", "g2"], values.Mirrored)

    # Innermost entry: a dict with exactly the keys "c" and "e".
    inner = middle[1]
    self.assertIsInstance(inner, dict)
    self.assertEqual({"c", "e"}, set(inner.keys()))
    self._is_per_replica(inner["c"], ["d1", "d2"], values.Mirrored)
    self._is_per_replica(inner["e"], ["f1", "f2"], values.Mirrored)

    # Also test that we can undo the merge using select_replica()
    for replica_id, suffix in enumerate(("1", "2")):
      self.assertEqual(_nested_value(suffix),
                       values.select_replica(replica_id, result))

    # Values are marked as mirrored, so select_device_mirrored() is allowed.
    for device_index, suffix in enumerate(("1", "2")):
      self.assertEqual(
          _nested_value(suffix),
          values.select_device_mirrored(_device_str(device_index), result))
  def _test_input_fn_iterator(self,
                              iterator,
                              devices,
                              expected_values,
                              sess=None,
                              test_reinitialize=True):
    """Drains `iterator` and checks the per-replica values it produces.

    Args:
      iterator: object exposing `initialize()` and `get_next()`.
      devices: sequence whose length determines how many replicas are
        selected from each element.
      expected_values: iterable of expected per-step results, each compared
        with `assertEqual` against the list of per-replica values.
      sess: optional session; when truthy its `run` is used, otherwise
        `self.evaluate`.
      test_reinitialize: when true, re-initializes the iterator after it is
        exhausted and checks the full sequence a second time.
    """

    def evaluate(fetches):
      """Evaluates via the given session if any, else `self.evaluate`."""
      if sess:
        return sess.run(fetches)
      return self.evaluate(fetches)

    def fetch_next():
      """Gets the next element and evaluates one component per replica."""
      element = iterator.get_next()
      return evaluate([
          values.select_replica(replica, element)
          for replica in range(len(devices))
      ])

    def check_one_pass():
      """Compares one complete pass over the data with `expected_values`."""
      for expected in expected_values:
        self.assertEqual(expected, fetch_next())

    evaluate(iterator.initialize())
    check_one_pass()

    # One more fetch past the end must signal exhaustion.
    with self.assertRaises(errors.OutOfRangeError):
      fetch_next()

    # After re-initializing the iterator, should be able to iterate again.
    if test_reinitialize:
      evaluate(iterator.initialize())
      check_one_pass()
示例#5
0
def _tpu_run(strategy, fn, args, kwargs):
  """Common implementation of TPUStrategy.experimental_run_v2.

  Args:
    strategy: the strategy whose scope, replica count and device map are used.
    fn: user function invoked once per replica inside a `_TPUReplicaContext`.
    args: positional arguments; the per-replica component is selected with
      `values.select_replica`.
    kwargs: keyword arguments (or None, treated as `{}`), selected per replica
      the same way.

  Returns:
    The per-replica outputs of `fn`, merged with `values.regroup`.

  Raises:
    NotImplementedError: when executing eagerly outside of a function.
  """
  if context.executing_eagerly() and not ops.inside_function():
    raise NotImplementedError(
        "Eager mode not supported in TPUStrategy outside TF functions.")

  kwargs = {} if kwargs is None else kwargs

  # One-slot holder for the structure returned by `fn`. It is used to
  # re-structure flattened output tensors from `tpu.replicate()` into a
  # structured format.
  fn_output = [[]]

  def replicated_fn(replica_id, replica_args, replica_kwargs):
    """Wraps user function to provide replica ID and `Tensor` inputs."""
    with _TPUReplicaContext(strategy, replica_id_in_sync_group=replica_id):
      fn_output[0] = fn(*replica_args, **replica_kwargs)
    return fn_output[0]

  # One entry per replica: [replica id, selected args, selected kwargs].
  replicate_inputs = [
      [constant_op.constant(replica, dtype=dtypes.int32),
       values.select_replica(replica, args),
       values.select_replica(replica, kwargs)]
      for replica in range(strategy.num_replicas_in_sync)
  ]

  # Construct and pass `maximum_shapes` so that we could support dynamic
  # shapes using dynamic padder. The shapes mirror the structure of the first
  # replica's inputs.
  maximum_shapes = None
  if replicate_inputs:
    first_replica_inputs = replicate_inputs[0]
    maximum_shapes = nest.pack_sequence_as(
        first_replica_inputs,
        [tensor.get_shape() for tensor in nest.flatten(first_replica_inputs)])

  with strategy.scope():
    replicate_outputs = tpu.replicate(replicated_fn, replicate_inputs,
                                      maximum_shapes=maximum_shapes)

  # Remove all no ops that may have been added during 'tpu.replicate()'
  output_structure = fn_output[0]
  if isinstance(output_structure, list):
    output_structure = [
        item for item in output_structure if tensor_util.is_tensor(item)
    ]

  # Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
  structured_outputs = [
      nest.pack_sequence_as(output_structure, nest.flatten(per_replica))
      for per_replica in replicate_outputs
  ]

  device_map = strategy.extended._device_map  # pylint: disable=protected-access
  return values.regroup(device_map, structured_outputs)
示例#6
0
  def testNamedTupleEstimatorSpec(self):
    """Checks that regroup() merges per-device `EstimatorSpec` objects.

    One TRAIN-mode `EstimatorSpec` is built per device for three devices and
    merged with `values.regroup()`. The merged object must still be an
    `EstimatorSpec` in TRAIN mode; its `loss`, `train_op` and `scaffold`
    fields must return each device's original component via `.get(device)`,
    and `values.select_replica()` must give back each original spec.
    """
    with context.graph_mode(), ops.Graph().as_default():
      devices = []
      created_estimator_specs = []

      for device_id in range(3):
        spec = model_fn_lib.EstimatorSpec(
            mode=model_fn_lib.ModeKeys.TRAIN,
            loss=constant_op.constant(device_id / 2),
            train_op=array_ops.identity(constant_op.constant(device_id)))
        devices.append(_device_str(device_id))
        created_estimator_specs.append(spec)

      device_map = values.ReplicaDeviceMap(devices)
      merged_estimator_spec = values.regroup(
          device_map, created_estimator_specs)

      # assertIsInstance reports the actual type on failure, which
      # assertTrue(isinstance(...)) does not.
      self.assertIsInstance(merged_estimator_spec, model_fn_lib.EstimatorSpec)
      self.assertEqual(model_fn_lib.ModeKeys.TRAIN, merged_estimator_spec.mode)
      for device_id, created_spec in enumerate(created_estimator_specs):
        d = _device_str(device_id)
        self.assertEqual(created_spec.loss,
                         merged_estimator_spec.loss.get(d))
        self.assertEqual(created_spec.train_op,
                         merged_estimator_spec.train_op.get(d))
        # Scaffold is populated by `EstimatorSpec.__new__`.
        self.assertEqual(created_spec.scaffold,
                         merged_estimator_spec.scaffold.get(d))
        # Also test that we can undo the merge using select_replica()
        self.assertEqual(created_spec,
                         values.select_replica(device_id,
                                               merged_estimator_spec))
示例#7
0
  def _test_iterator(self, sess, iterator, devices, expected_values):
    """Checks device placement and produced values of `iterator`.

    `iterator.get_next()` is called once. For each replica, every flattened
    component of the selected element must report a `.device` that is
    contained in the corresponding entry of `devices`. The same fetches are
    then run once per entry of `expected_values` and compared with
    `assertEqual`; one further run must raise `errors.OutOfRangeError`.

    Args:
      sess: session used to run the fetches.
      iterator: object exposing `get_next()`.
      devices: sequence of device strings, one per replica.
      expected_values: iterable of expected per-step results.
    """
    next_element = iterator.get_next()
    for r, device in enumerate(devices):
      v = values.select_replica(r, next_element)
      # The `v` here can be a tuple.
      for element in nest.flatten(v):
        # assertIn prints both operands on failure, unlike assertTrue(a in b).
        self.assertIn(element.device, device)

    # `next_element` is obtained only once, so the per-replica fetch list is
    # the same for every step; build it a single time.
    per_replica_fetches = [
        values.select_replica(r, next_element) for r in range(len(devices))
    ]

    for expected_value in expected_values:
      actual = sess.run(per_replica_fetches)
      self.assertEqual(expected_value, actual)

    with self.assertRaises(errors.OutOfRangeError):
      sess.run(per_replica_fetches)
示例#8
0
  def experimental_run_v2(self, fn, args=(), kwargs=None):
    """See base class.

    Runs `fn` once per replica through `tpu.replicate()` and regroups the
    per-replica outputs.

    Args:
      fn: user function; called inside a `_TPUReplicaContext` with the
        per-replica components of `args` and `kwargs`.
      args: positional arguments; per-replica components are selected with
        `values.select_replica`.
      kwargs: keyword arguments, or None (treated as an empty dict).

    Returns:
      The result of `values.regroup` over the structured per-replica outputs.

    Raises:
      NotImplementedError: when executing eagerly outside of a function.
    """
    if context.executing_eagerly() and not ops.inside_function():
      raise NotImplementedError(
          "Eager mode not supported in TPUStrategy outside TF functions.")

    if kwargs is None:
      kwargs = {}

    # Used to re-structure flattened output tensors from `tpu.replicate()`
    # into a structured format.
    # A one-element list is used so the nested function below can store the
    # return value of `fn` without rebinding a name in this scope.
    result = [[]]

    def replicated_fn(replica_id, replica_args, replica_kwargs):
      """Wraps user function to provide replica ID and `Tensor` inputs."""
      with _TPUReplicaContext(self, replica_id_in_sync_group=replica_id):
        result[0] = fn(*replica_args, **replica_kwargs)
      return result[0]

    # Each entry is [replica id constant, args for replica, kwargs for replica].
    replicate_inputs = []  # By replica.
    for i in range(self.num_replicas_in_sync):
      replicate_inputs.append(
          [constant_op.constant(i, dtype=dtypes.int32),
           values.select_replica(i, args),
           values.select_replica(i, kwargs)])

    with self.scope():
      replicate_outputs = tpu.replicate(replicated_fn, replicate_inputs)

    # Remove all no ops that may have been added during 'tpu.replicate()'
    # Only list-typed results are filtered; other structures are left as-is.
    if isinstance(result[0], list):
      result[0] = [
          output for output in result[0] if tensor_util.is_tensor(output)
      ]

    # Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
    # Every replica's output is flattened and re-packed into the structure
    # that `fn` returned.
    replicate_outputs = [
        nest.pack_sequence_as(result[0], nest.flatten(replica_output))
        for replica_output in replicate_outputs
    ]

    device_map = self.extended._device_map  # pylint: disable=protected-access
    return values.regroup(device_map, replicate_outputs)
示例#9
0
  def testSameId(self):
    """Checks regroup() when replicas share the very same object.

    Two replicas contribute ("a", shared) and ("b", shared). The regrouped
    tuple must hold a per-replica value for the first slot and the shared
    object itself (by identity) for the second. `select_replica()` must then
    give back a 2-tuple per replica with the original first element and the
    same shared object.
    """
    shared = object()
    device_map = values.ReplicaDeviceMap((_device_str(0), _device_str(1)))
    result = values.regroup(device_map, (("a", shared), ("b", shared)))
    self.assertIsInstance(result, tuple)
    self.assertEqual(2, len(result))
    self._is_per_replica(result[0], ["a", "b"])
    self.assertIs(shared, result[1])

    # Test select_replica(), should undo the merge done by regroup().
    for replica_id, expected_first in enumerate(("a", "b")):
      selected = values.select_replica(replica_id, result)
      self.assertIsInstance(selected, tuple)
      self.assertEqual(2, len(selected))
      self.assertEqual(expected_first, selected[0])
      self.assertIs(shared, selected[1])
示例#10
0
  def _test_iterator(self,
                     input_type,
                     dataset_fn,
                     worker_device_pairs,
                     expected_values,
                     sess=None,
                     split_batch_by=None,
                     enable_get_next_as_optional=False):
    """Builds an iterator via `self._create_iterator` and checks its values.

    The iterator is initialized, drained while comparing each step with
    `expected_values`, required to raise `errors.OutOfRangeError` on the next
    fetch, then re-initialized and checked a second time.

    Args:
      input_type: forwarded to `self._create_iterator`.
      dataset_fn: forwarded to `self._create_iterator`.
      worker_device_pairs: sequence of (worker, devices) pairs; the device
        parts are flattened to determine the replicas.
      expected_values: iterable of expected per-step results; each is
        compared length-wise and element-wise with `assertAllEqual`.
      sess: optional session; when truthy its `run` is used, otherwise
        `self.evaluate`.
      split_batch_by: forwarded to `self._create_iterator`.
      enable_get_next_as_optional: forwarded to `self._create_iterator`.
    """
    devices = nest.flatten([ds for _, ds in worker_device_pairs])
    iterator = self._create_iterator(
        input_type, dataset_fn, worker_device_pairs, devices, split_batch_by,
        enable_get_next_as_optional)

    def evaluate(fetches):
      """Evaluates via the given session if any, else `self.evaluate`."""
      if sess:
        return sess.run(fetches)
      return self.evaluate(fetches)

    def initialize():
      """Runs the iterator's initializers as a single grouped op."""
      evaluate(control_flow_ops.group(iterator.initialize()))

    def fetch_next():
      """Gets the next element and evaluates one component per replica."""
      element = iterator.get_next()
      return evaluate([
          values.select_replica(replica, element)
          for replica in range(len(devices))
      ])

    def check_one_pass():
      """Compares one complete pass over the data with `expected_values`."""
      for expected in expected_values:
        computed = fetch_next()
        self.assertEqual(len(expected), len(computed))
        for expected_item, computed_item in zip(expected, computed):
          self.assertAllEqual(expected_item, computed_item)

    initialize()
    check_one_pass()

    # One more fetch past the end must signal exhaustion.
    with self.assertRaises(errors.OutOfRangeError):
      fetch_next()

    # After re-initializing the iterator, should be able to iterate again.
    initialize()
    check_one_pass()
示例#11
0
  def _test_iterator(self, input_type, dataset_fn, worker_device_pairs,
                     expected_values, sess=None, split_batch_by=None):
    """Builds an input iterator over the given workers and checks its values.

    For `input_type == "input_fn"` an `InputFunctionIterator` is created with
    one `InputContext` per worker; otherwise a `DatasetIterator` over
    `dataset_fn()` is used. The iterator is initialized, drained while
    comparing each step with `expected_values`, required to raise
    `errors.OutOfRangeError` on the next fetch, then re-initialized and
    checked a second time.

    Args:
      input_type: "input_fn" selects the input-function iterator; any other
        value selects the dataset iterator.
      dataset_fn: zero-argument callable returning the dataset.
      worker_device_pairs: sequence of (worker, devices) pairs.
      expected_values: iterable of expected per-step results, each compared
        with `assertAllEqual`.
      sess: optional session; when truthy its `run` is used, otherwise
        `self.evaluate`.
      split_batch_by: forwarded to `input_lib.DatasetIterator`.
    """
    devices = nest.flatten([ds for _, ds in worker_device_pairs])
    device_map = values.ReplicaDeviceMap(devices)
    input_workers = input_lib.InputWorkers(device_map, worker_device_pairs)

    if input_type != "input_fn":
      iterator = input_lib.DatasetIterator(
          dataset_fn(), input_workers, split_batch_by)
    else:
      input_contexts = [
          distribute_lib.InputContext() for _ in worker_device_pairs]

      def input_fn(_):
        """Ignores the input context and returns the dataset."""
        return dataset_fn()

      iterator = input_lib.InputFunctionIterator(
          input_fn, input_workers, input_contexts)

    def evaluate(fetches):
      """Evaluates via the given session if any, else `self.evaluate`."""
      if sess:
        return sess.run(fetches)
      return self.evaluate(fetches)

    def initialize():
      """Runs the iterator's initializers as a single grouped op."""
      evaluate(control_flow_ops.group(iterator.initialize()))

    def fetch_next():
      """Gets the next element and evaluates one component per replica."""
      element = iterator.get_next()
      return evaluate([
          values.select_replica(replica, element)
          for replica in range(len(devices))
      ])

    def check_one_pass():
      """Compares one complete pass over the data with `expected_values`."""
      for expected in expected_values:
        self.assertAllEqual(expected, fetch_next())

    initialize()
    check_one_pass()

    # One more fetch past the end must signal exhaustion.
    with self.assertRaises(errors.OutOfRangeError):
      fetch_next()

    # After re-initializing the iterator, should be able to iterate again.
    initialize()
    check_one_pass()
示例#12
0
  def testOneDevice(self):
    """Checks regroup()/select_replica() behaviour with a single device.

    With one device the regrouped value must compare equal to the input and
    `select_replica(0, ...)` must return it as well. Separately, regrouping a
    single variable that is the component of a `MirroredVariable` must return
    that `MirroredVariable` object itself.
    """
    single_device_map = values.ReplicaDeviceMap((_device_str(0),))
    regrouped = values.regroup(single_device_map, (_nested_value("1"),))
    # On one device regroup() and select_replica() are basically identity.
    self.assertEqual(_nested_value("1"), regrouped)
    self.assertEqual(_nested_value("1"),
                     values.select_replica(0, regrouped))

    # The one exception has to do with MirroredVariables.
    cpu_device = "/device:CPU:0"
    with ops.device(cpu_device):
      component = variable_scope.get_variable(
          name="v", initializer=1., use_resource=True)
      cpu_device_map = values.ReplicaDeviceMap((cpu_device,))
    mirrored = values.MirroredVariable(
        None, cpu_device_map, (component,),
        variable_scope.VariableAggregation.SUM)
    regrouped_variable = values.regroup(cpu_device_map, (component,))
    self.assertIs(mirrored, regrouped_variable)
示例#13
0
  def experimental_run(self, fn, input_iterator=None):
    """See base class.

    Runs `fn` once per replica through `tpu.replicate()`. When
    `input_iterator` is given, its next element is fetched and the
    per-replica component is passed to `fn`; otherwise `fn` is called with no
    arguments.

    Raises:
      NotImplementedError: when executing eagerly, or when
        `self.extended._disable_training_loop_on_host` is set.
    """
    if context.executing_eagerly():
      raise NotImplementedError("Eager mode not supported in TPUStrategy.")

    if self.extended._disable_training_loop_on_host:  # pylint: disable=protected-access
      raise NotImplementedError(
          "`experimental_run` is not compatible with "
          "`_disable_training_loop_on_host=True`")

    inputs = [] if input_iterator is None else input_iterator.get_next()

    # One-slot holder for the structure returned by `fn`; used below to
    # re-pack the outputs of `tpu.replicate()`.
    fn_output = [None]

    def replicated_fn(replica_id, replica_inputs):
      """Wraps user function to provide replica ID and `Tensor` inputs."""
      with _TPUReplicaContext(self, replica_id_in_sync_group=replica_id):
        if input_iterator is None:
          fn_output[0] = fn()
        else:
          fn_output[0] = fn(replica_inputs)
      return fn_output[0]

    # One entry per replica: [replica id, selected inputs].
    replicate_inputs = [
        [constant_op.constant(replica, dtype=dtypes.int32),
         values.select_replica(replica, inputs)]
        for replica in range(self.num_replicas_in_sync)
    ]

    with self.scope():
      replicate_outputs = tpu.replicate(replicated_fn, replicate_inputs)

    # Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
    structured_outputs = [
        nest.pack_sequence_as(fn_output[0], nest.flatten(per_replica))
        for per_replica in replicate_outputs
    ]

    device_map = self.extended._device_map  # pylint: disable=protected-access
    return values.regroup(device_map, structured_outputs)
示例#14
0
    def rewrite_fn(*args):
      """The rewritten step fn running on TPU.

      Fetches the next per-replica inputs from `multi_worker_iterator`, hands
      each replica its own slice to `run_fn` via `tpu.replicate`, and returns
      the replicated outputs (flattened when they are lists).
      """
      del args

      per_replica_inputs = multi_worker_iterator.get_next()

      def inputs_for(replica_id):
        """Builds the one-element argument tuple for a single replica."""
        selected = nest.map_structure(
            lambda x: values.select_replica(replica_id, x),
            per_replica_inputs)
        return (selected,)

      replicate_inputs = [
          inputs_for(replica_id)
          for replica_id in range(self._num_replicas_in_sync)
      ]

      replicate_outputs = tpu.replicate(run_fn, replicate_inputs)

      # If run_fn has tensor outputs, tpu.replicate returns a list of list. We
      # will flatten it in this case. If run_fn has no tensor outputs,
      # tpu.replicate returns a list of no_ops, we will keep the output as it
      # is.
      if isinstance(replicate_outputs[0], list):
        return nest.flatten(replicate_outputs)
      return replicate_outputs
示例#15
0
        def tpu_function(args, kwargs):
            """TF Function used to replicate the user computation.

            Calls the enclosing scope's `fn` once per replica through
            `tpu.replicate()` and returns the per-replica outputs merged with
            `values.regroup`. When `fn` returns None, every replica's output
            is None.
            """
            kwargs = {} if kwargs is None else kwargs

            # One-slot holder for the structure returned by `fn`. It is used
            # to re-structure flattened output tensors from `tpu.replicate()`
            # into a structured format.
            fn_output = [[]]

            def replicated_fn(replica_id, replica_args, replica_kwargs):
                """Wraps user function to provide replica ID and `Tensor` inputs."""
                with _TPUReplicaContext(strategy,
                                        replica_id_in_sync_group=replica_id):
                    fn_output[0] = fn(*replica_args, **replica_kwargs)
                return fn_output[0]

            # One entry per replica: [replica id, selected args, selected kwargs].
            replicate_inputs = [
                [
                    constant_op.constant(replica, dtype=dtypes.int32),
                    values.select_replica(replica, args),
                    values.select_replica(replica, kwargs),
                ]
                for replica in range(strategy.num_replicas_in_sync)
            ]

            # Construct and pass `maximum_shapes` so that we could support dynamic
            # shapes using dynamic padder. Tensors contribute their static
            # shape; anything else contributes `np.shape(...)`.
            maximum_shapes = None
            if replicate_inputs:
                first_replica_inputs = replicate_inputs[0]
                leaf_shapes = [
                    leaf.get_shape() if tensor_util.is_tensor(leaf)
                    else tensor_shape.TensorShape(np.shape(leaf))
                    for leaf in nest.flatten(first_replica_inputs)
                ]
                maximum_shapes = nest.pack_sequence_as(first_replica_inputs,
                                                       leaf_shapes)

            with strategy.scope():
                replicate_outputs = tpu.replicate(
                    replicated_fn,
                    replicate_inputs,
                    maximum_shapes=maximum_shapes)

            # Remove all no ops that may have been added during 'tpu.replicate()'
            output_structure = fn_output[0]
            if isinstance(output_structure, list):
                output_structure = [
                    item for item in output_structure
                    if tensor_util.is_tensor(item)
                ]

            # Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
            if output_structure is None:
                structured_outputs = [None] * len(replicate_outputs)
            else:
                structured_outputs = [
                    nest.pack_sequence_as(output_structure,
                                          nest.flatten(per_replica))
                    for per_replica in replicate_outputs
                ]
            device_map = self._device_map  # pylint: disable=protected-access
            return values.regroup(device_map, structured_outputs)
示例#16
0
        def tpu_function(args, kwargs):
            """TF Function used to replicate the user computation.

            Calls the enclosing scope's `fn` once per replica through
            `tpu.replicate()` and returns the per-replica outputs merged with
            `values.regroup`. Trailing None entries of `args` are dropped
            first. When `fn` returns None, every replica's output is None.
            """
            if kwargs is None:
                kwargs = {}

            # Remove None at the end of args as they are not replicatable
            # If there are None in the middle we can't do anything about it
            # so let those cases fail.
            # For example when Keras model predict is used they pass the targets as
            # None. We want to handle it here so all client libraries don't have to
            # do this as other strategies can handle None values better.
            while args and args[-1] is None:
                args = args[:-1]

            # Used to re-structure flattened output tensors from `tpu.replicate()`
            # into a structured format.
            # A one-element list is used so the nested function below can
            # store the return value of `fn` without rebinding a local name.
            result = [[]]

            def replicated_fn(replica_id, replica_args, replica_kwargs):
                """Wraps user function to provide replica ID and `Tensor` inputs."""
                with _TPUReplicaContext(strategy,
                                        replica_id_in_sync_group=replica_id):
                    result[0] = fn(*replica_args, **replica_kwargs)
                return result[0]

            # Each entry is [replica id constant, args for replica, kwargs
            # for replica].
            replicate_inputs = []  # By replica.
            for i in range(strategy.num_replicas_in_sync):
                replicate_inputs.append([
                    constant_op.constant(i, dtype=dtypes.int32),
                    values.select_replica(i, args),
                    values.select_replica(i, kwargs)
                ])

            # Construct and pass `maximum_shapes` so that we could support dynamic
            # shapes using dynamic padder.
            # Shapes are only computed when the dynamic-batch-size flag is set;
            # otherwise `maximum_shapes` is passed as None.
            if self.experimental_enable_dynamic_batch_size and replicate_inputs:
                maximum_shapes = []
                flattened_list = nest.flatten(replicate_inputs[0])
                for input_tensor in flattened_list:
                    # Non-tensor inputs have no `get_shape()`; derive their
                    # shape with numpy instead.
                    if tensor_util.is_tensor(input_tensor):
                        maximum_shape = input_tensor.get_shape()
                    else:
                        maximum_shape = tensor_shape.TensorShape(
                            np.shape(input_tensor))
                    maximum_shapes.append(maximum_shape)
                # Give the shapes the same nesting as the first replica's inputs.
                maximum_shapes = nest.pack_sequence_as(replicate_inputs[0],
                                                       maximum_shapes)
            else:
                maximum_shapes = None

            with strategy.scope():
                replicate_outputs = tpu.replicate(
                    replicated_fn,
                    replicate_inputs,
                    device_assignment=self._device_assignment,
                    maximum_shapes=maximum_shapes)

            # Remove all no ops that may have been added during 'tpu.replicate()'
            # Only list-typed results are filtered; other structures are left
            # as-is.
            if isinstance(result[0], list):
                result[0] = [
                    output for output in result[0]
                    if tensor_util.is_tensor(output)
                ]

            # Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
            # A None result from `fn` yields None for every replica; otherwise
            # each replica's output is re-packed into the structure `fn`
            # returned.
            if result[0] is None:
                replicate_outputs = [None] * len(replicate_outputs)
            else:
                replicate_outputs = [
                    nest.pack_sequence_as(result[0],
                                          nest.flatten(replica_output))
                    for replica_output in replicate_outputs
                ]
            device_map = self._device_map  # pylint: disable=protected-access
            return values.regroup(device_map, replicate_outputs)
示例#17
0
  def _test_input_iteration(self,
                            input_type,
                            api_type,
                            iteration_type,
                            dataset_fn,
                            worker_device_pairs,
                            expected_values,
                            sess=None,
                            split_batch_by=None,
                            enable_get_next_as_optional=False):
    """Checks distributed input values for one API/iteration combination.

    Args:
      input_type: kind of input; "input_fn" combined with
        `api_type == "wrap_into_dataset"` is skipped. Otherwise forwarded to
        `self._wrap_iterator` / `self._wrap_dataset`.
      api_type: "wrap_into_iterator" builds the iterator through
        `self._wrap_iterator`; any other value wraps a dataset through
        `self._wrap_dataset`.
      iteration_type: "get_next" drives the iterator explicitly (two full
        passes with an exhaustion check in between); "for_loop" iterates the
        wrapped dataset directly and requires eager execution.
      dataset_fn: callable producing the dataset; in the dataset-wrapping
        path it is called with a `distribute_lib.InputContext()`.
      worker_device_pairs: sequence of (worker, devices) pairs.
      expected_values: sequence of expected per-step results; each is
        compared length-wise and element-wise with `assertAllEqual`.
      sess: optional session; when truthy its `run` is used in the
        "get_next" path, otherwise `self.evaluate`.
      split_batch_by: forwarded to the wrapping helper.
      enable_get_next_as_optional: forwarded to the wrapping helper.
    """
    # Unsupported combinations are skipped up front.
    if iteration_type == "for_loop" and not context.executing_eagerly():
      self.skipTest("unsupported test combination.")

    if api_type == "wrap_into_iterator" and iteration_type == "for_loop":
      self.skipTest("unsupported test combination.")

    if api_type == "wrap_into_dataset" and input_type == "input_fn":
      self.skipTest("unsupported test combination.")

    devices = nest.flatten([ds for _, ds in worker_device_pairs])
    device_map = values.ReplicaDeviceMap(devices)
    input_workers = input_lib.InputWorkers(device_map, worker_device_pairs)

    if api_type == "wrap_into_iterator":
      iterator = self._wrap_iterator(
          input_type, dataset_fn, input_workers, devices, split_batch_by,
          enable_get_next_as_optional)
    else:
      # wrapping into a dataset:
      given_dataset = dataset_fn(distribute_lib.InputContext())
      dataset = self._wrap_dataset(input_type, given_dataset, input_workers,
                                   split_batch_by, enable_get_next_as_optional)

      if context.executing_eagerly():
        iterator = iter(dataset)
      else:
        # The dataset can be a tf.data.DatasetV1Adapter instance since we wrap
        # tf.data.DatasetV1 as a tf.data.DatasetV1Adapter instance when we
        # autoshard the dataset.
        if not isinstance(dataset, (dataset_ops.DatasetV1,
                                    dataset_ops.DatasetV1Adapter)):
          iterator = iter(dataset)
        else:
          iterator = dataset.make_one_shot_iterator()

    if iteration_type == "get_next":
      evaluate = lambda x: sess.run(x) if sess else self.evaluate(x)
      # DistributedIteratorV1 exposes `initialize()`; for other iterator
      # types the private `_initializer` attribute is used instead.
      if isinstance(iterator, input_lib.DistributedIteratorV1):
        evaluate(control_flow_ops.group(iterator.initialize()))
      else:
        evaluate(control_flow_ops.group(iterator._initializer))

      # First pass: every step must match the expected per-replica values.
      for expected_value in expected_values:
        next_element = iterator.get_next()
        computed_value = evaluate(
            [values.select_replica(r,
                                   next_element) for r in range(len(devices))])
        self.assertEqual(len(expected_value), len(computed_value))
        for i in range(len(expected_value)):
          self.assertAllEqual(expected_value[i], computed_value[i])

      # One more fetch past the end must signal exhaustion.
      with self.assertRaises(errors.OutOfRangeError):
        next_element = iterator.get_next()
        evaluate(
            [values.select_replica(r,
                                   next_element) for r in range(len(devices))])

      # After re-initializing the iterator, should be able to iterate again.
      if isinstance(iterator, input_lib.DistributedIteratorV1):
        evaluate(control_flow_ops.group(iterator.initialize()))
      else:
        evaluate(control_flow_ops.group(iterator._initializer))

      for expected_value in expected_values:
        next_element = iterator.get_next()
        computed_value = evaluate(
            [values.select_replica(r,
                                   next_element) for r in range(len(devices))])
        self.assertEqual(len(expected_value), len(computed_value))
        for i in range(len(expected_value)):
          self.assertAllEqual(expected_value[i], computed_value[i])

    # `dataset` is always bound here: the "wrap_into_iterator" + "for_loop"
    # combination was skipped above, so the dataset-wrapping branch ran.
    if iteration_type == "for_loop" and context.executing_eagerly():
      actual_values = []
      for x in dataset:
        computed_value = self.evaluate(
            [values.select_replica(r, x) for r in range(len(devices))])
        actual_values.append(computed_value)
      # NOTE(review): only the first len(expected_values) steps are compared;
      # extra elements yielded by the dataset are not flagged.
      for i, expected_value in enumerate(expected_values):
        self.assertEqual(len(expected_value), len(actual_values[i]))
        for j in range(len(expected_value)):
          self.assertAllEqual(expected_value[j], actual_values[i][j])
示例#18
0
def _call_for_each_replica(distribution, devices, fn, args, kwargs):
    """Run `fn` in separate threads, once per replica/worker device.

    One `_MirroredReplicaThread` is created per entry in `devices`. The calling
    thread then drives them in rounds: it sets each thread's `should_run`
    event, waits on its `has_paused` event, and, when the threads have paused
    without being done, runs the first thread's `merge_fn` on the regrouped
    merge arguments and stores each replica's part of the result on its
    thread as `merge_result`.

    Args:
      distribution: the DistributionStrategy object.
      devices: the devices to run `fn` on (logical device 0 for each replica).
      fn: function to run (will be run once per replica, each in its own
        thread).
      args: positional arguments for `fn`; replica `i` is given
        `values.select_replica(i, args)`.
      kwargs: keyword arguments for `fn`; replica `i` is given
        `values.select_replica(i, kwargs)`.

    Returns:
      Merged return value of `fn` across all replicas (every thread's
      `main_result`, combined with `values.regroup`), or `None` when
      `coord.should_stop()` becomes true while the threads are being driven.
      NOTE(review): `coord.join` in the `finally` block presumably re-raises
      an exception recorded by the coordinator instead of letting `None`
      through -- confirm against the Coordinator documentation.

    Raises:
      RuntimeError: If, within one round, some replica threads report `done`
        while others do not, i.e. the replicas made a different number of
        `get_replica_context().merge_call()` calls.
    """
    # TODO(josh11b): Add this option once we add synchronization to variable
    # creation. Until then, this is pretty unsafe to use.
    # While this stays False only the sequential branch of the loop below runs:
    # each thread is released and waited on before the next one is released.
    run_concurrently = False
    if not context.executing_eagerly():
        # Needed for per-thread device, etc. contexts in graph mode.
        ops.get_default_graph().switch_to_thread_local()

    # `_RequestedStop` is registered as a clean-stop exception type for the
    # coordinator shared by all replica threads.
    coord = coordinator.Coordinator(
        clean_stop_exception_types=(_RequestedStop, ))

    # A single store handed to every replica's variable creator function.
    shared_variable_store = {}

    # TODO(isaprykin): Create these threads once instead of during every call.
    threads = []
    for index in range(len(devices)):
        variable_creator_fn = shared_variable_creator.make_fn(
            shared_variable_store, index)
        # Each thread gets only its own replica's slice of `args`/`kwargs`.
        t = _MirroredReplicaThread(distribution, coord, index, devices,
                                   variable_creator_fn, fn,
                                   values.select_replica(index, args),
                                   values.select_replica(index, kwargs))
        threads.append(t)

    for t in threads:
        t.start()

    # When `fn` starts `should_run` event is set on _MirroredReplicaThread
    # (`MRT`) threads. The execution waits until
    # `MRT.has_paused` is set, which indicates that either `fn` is
    # complete or a `get_replica_context().merge_call()` is called.  If `fn` is
    # complete, then `MRT.done` is set to True.  Otherwise, arguments
    # of `get_replica_context().merge_call` from all paused threads are grouped
    # and the `merge_fn` is performed.  Results of the
    # `get_replica_context().merge_call` are then set to `MRT.merge_result`.
    # Each such `get_replica_context().merge_call` call returns the
    # `MRT.merge_result` for that thread when `MRT.should_run` event
    # is reset again. Execution of `fn` resumes.

    try:
        with coord.stop_on_exception():
            all_done = False
            # One iteration per round: release the threads, wait for them to
            # pause, then either finish or perform the merge.
            while not all_done and not coord.should_stop():
                # `done[i]` records whether thread i reported `done` this round.
                done = []
                if run_concurrently:
                    # Release every thread first, then wait for each in turn.
                    for t in threads:
                        t.should_run.set()
                    for t in threads:
                        t.has_paused.wait()
                        t.has_paused.clear()
                        # Early exit; the `finally` block below still runs.
                        if coord.should_stop():
                            return None
                        done.append(t.done)
                else:
                    # Release one thread and wait for it to pause before
                    # releasing the next one.
                    for t in threads:
                        t.should_run.set()
                        t.has_paused.wait()
                        t.has_paused.clear()
                        # Early exit; the `finally` block below still runs.
                        if coord.should_stop():
                            return None
                        done.append(t.done)
                if coord.should_stop():
                    return None
                all_done = all(done)
                if not all_done:
                    # A mix of done and not-done threads means the replicas
                    # disagree on how many merge_call()s they made.
                    if any(done):
                        raise RuntimeError(
                            "Some replicas made a different number of "
                            "replica_context().merge_call() calls.")
                    # get_replica_context().merge_call() case
                    merge_args = values.regroup(
                        tuple(t.merge_args for t in threads))
                    merge_kwargs = values.regroup(
                        tuple(t.merge_kwargs for t in threads))
                    # We capture the name_scope of the MRT when we call merge_fn
                    # to ensure that if we have opened a name scope in the MRT,
                    # it will be respected when executing the merge function. We only
                    # capture the name_scope from the first MRT and assume it is
                    # the same for all other MRTs. The variable scope is taken
                    # from the first MRT in the same way.
                    mtt_captured_name_scope = threads[0].captured_name_scope
                    mtt_captured_var_scope = threads[0].captured_var_scope
                    # Capture and merge the control dependencies from all the threads.
                    mtt_captured_control_deps = set()
                    for t in threads:
                        mtt_captured_control_deps.update(
                            t.captured_control_deps)
                    # Only the first thread's merge_fn is invoked, once, with
                    # the regrouped arguments of all threads.
                    with ops.name_scope(mtt_captured_name_scope),\
                        ops.control_dependencies(mtt_captured_control_deps), \
                        variable_scope.variable_scope(mtt_captured_var_scope):
                        merge_result = threads[0].merge_fn(
                            distribution, *merge_args, **merge_kwargs)
                    # Hand each thread its own replica's part of the result.
                    for r, t in enumerate(threads):
                        t.merge_result = values.select_replica(r, merge_result)
    finally:
        # Runs on normal completion, early `return None`, and exceptions:
        # set `should_run` on every thread (presumably so none remains blocked
        # waiting for it) before joining them through the coordinator.
        for t in threads:
            t.should_run.set()
        coord.join(threads)

    return values.regroup(tuple(t.main_result for t in threads))
示例#19
0
    def _test_input_iteration(self,
                              input_type,
                              api_type,
                              iteration_type,
                              dataset_or_input_fn,
                              worker_device_pairs,
                              expected_values,
                              strategy,
                              sess=None,
                              split_batch_by=None,
                              input_context=None):
        """Wraps the input for the given workers and checks what it yields.

        Args:
          input_type: forwarded to `_wrap_iterator` / `_wrap_dataset`.
          api_type: "wrap_into_iterator" builds an iterator directly; any
            other value wraps the input into a dataset first.
          iteration_type: "get_next" pulls elements with `get_next()` (and
            also checks exhaustion and re-initialization); "for_loop" iterates
            over the dataset with a Python `for` loop, eager mode only.
          dataset_or_input_fn: the input to wrap.
          worker_device_pairs: pairs whose second element holds the devices
            (presumably `(worker, devices)`); used to build `InputWorkers`.
          expected_values: one entry per step, each holding one expected
            value per device.
          strategy: forwarded to the wrapping helpers.
          sess: optional session; when truthy, `sess.run` is used instead of
            `self.evaluate` for the "get_next" checks.
          split_batch_by: forwarded to the wrapping helpers.
          input_context: forwarded to the wrapping helpers.
        """
        # "for_loop" needs eager execution and a dataset (not a bare iterator).
        if iteration_type == "for_loop" and (
                not context.executing_eagerly()
                or api_type == "wrap_into_iterator"):
            self.skipTest("unsupported test combination.")

        devices = nest.flatten([ds for _, ds in worker_device_pairs])
        input_workers = input_lib.InputWorkers(worker_device_pairs)

        def per_replica(element):
            """Lists the component of `element` for each device, in order."""
            return [
                values.select_replica(replica_id, element)
                for replica_id in range(len(devices))
            ]

        def assert_step_matches(expected, computed):
            """Asserts one step's values match, replica by replica."""
            self.assertEqual(len(expected), len(computed))
            for want, got in zip(expected, computed):
                self.assertAllEqual(want, got)

        if api_type == "wrap_into_iterator":
            iterator = self._wrap_iterator(input_type,
                                           dataset_or_input_fn,
                                           input_workers,
                                           devices,
                                           split_batch_by,
                                           strategy,
                                           input_context=input_context)
        else:
            # Go through a dataset, then obtain an iterator from it.
            dataset = self._wrap_dataset(input_type,
                                         dataset_or_input_fn,
                                         input_workers,
                                         split_batch_by,
                                         strategy,
                                         input_context=input_context)

            if context.executing_eagerly():
                iterator = iter(dataset)
            elif isinstance(dataset, input_lib.DistributedDatasetV1):
                iterator = dataset.make_initializable_iterator()
            else:
                self.skipTest("unsupported test combination")

        if iteration_type == "get_next":

            def evaluate(fetches):
                """Runs `fetches` in `sess` when given, else `self.evaluate`."""
                if sess:
                    return sess.run(fetches)
                return self.evaluate(fetches)

            def check_full_pass():
                """Pulls one element per expected step and compares it."""
                for expected_value in expected_values:
                    next_element = iterator.get_next()
                    assert_step_matches(expected_value,
                                        evaluate(per_replica(next_element)))

            is_v1_iterator = isinstance(iterator,
                                        input_lib.DistributedIteratorV1)

            if is_v1_iterator:
                evaluate(control_flow_ops.group(iterator.initializer))

            check_full_pass()

            # The input must be exhausted after the expected steps.
            with self.assertRaises(errors.OutOfRangeError):
                next_element = iterator.get_next()
                evaluate(per_replica(next_element))

            # After re-initializing the iterator, should be able to iterate again.
            if is_v1_iterator:
                evaluate(control_flow_ops.group(iterator.initializer))
            else:
                evaluate(control_flow_ops.group(iterator._initializer))

            check_full_pass()

        if iteration_type == "for_loop" and context.executing_eagerly():
            actual_values = [self.evaluate(per_replica(x)) for x in dataset]
            for step, expected_value in enumerate(expected_values):
                assert_step_matches(expected_value, actual_values[step])
示例#20
0
    def testRaggedSparse(self, distribution, input_type, drop_remainder,
                         defun):
        """Exercises distributed input made of ragged and sparse tensors.

        A 20-row dataset with "dense", "ragged" and "sparse" features (all
        derived from one `RaggedTensor`) is wrapped for `distribution`. The
        test checks the first per-replica batch and then checks that summing
        every batch gives the same totals whether the dataset is consumed by
        a `for` loop or by repeated `next()` calls.
        """
        if not tf2.enabled():
            self.skipTest("Only V2 is supported.")

        distribution.extended.experimental_enable_get_next_as_optional = True
        global_batch_size = 8

        def dataset_fn(ctx=None):
            input_ctx = ctx or distribute_lib.InputContext()
            per_replica_batch_size = input_ctx.get_per_replica_batch_size(
                global_batch_size)
            # 20 rows are not a multiple of 8, so a partial batch is produced.
            lengths = np.mod(np.arange(20), 4).astype(np.int64)
            flat_values = np.repeat(np.arange(20, dtype=np.float32), lengths)
            ragged = ragged_tensor_lib.RaggedTensor.from_row_lengths(
                flat_values, lengths)
            features = {
                "dense": ragged.to_tensor(),
                "ragged": ragged,
                "sparse": ragged.to_sparse(),
            }
            sharded = dataset_ops.DatasetV2.from_tensor_slices(features).shard(
                input_ctx.num_input_pipelines, input_ctx.input_pipeline_id)
            return sharded.batch(per_replica_batch_size,
                                 drop_remainder=drop_remainder)

        wrapped_input = self._create_dataset_or_input_fn(input_type,
                                                         dataset_fn)
        dataset = self._wrap_dataset(input_type, wrapped_input,
                                     distribution.extended._input_workers,
                                     len(distribution.extended.worker_devices),
                                     distribution)

        # Assert that the tensors are rebatched and sparsity is preserved.
        per_replica_batch = defun(lambda x: next(iter(x)))(dataset)

        def replica_feature(replica_id, name):
            """Selects feature `name` of the first batch for one replica."""
            return values.select_replica(replica_id, per_replica_batch[name])

        expected_dense = (
            [[0., 0., 0.], [1., 0., 0.], [2., 2., 0.], [3., 3., 3.]],
            [[0., 0., 0.], [5., 0., 0.], [6., 6., 0.], [7., 7., 7.]],
        )
        for replica_id, dense_rows in enumerate(expected_dense):
            self.assertAllEqual(replica_feature(replica_id, "dense"),
                                dense_rows)

        # Transitively check the ragged and sparse tensors by densification.
        for replica_id in range(2):
            self.assertLen(replica_feature(replica_id, "ragged").values, 6)
            self.assertAllEqual(
                replica_feature(replica_id, "ragged").to_tensor(),
                replica_feature(replica_id, "dense"))
            self.assertLen(replica_feature(replica_id, "sparse").indices, 6)
            self.assertAllEqual(
                sparse_ops.sparse_tensor_to_dense(
                    replica_feature(replica_id, "sparse")),
                replica_feature(replica_id, "dense"))

        # Iterate through all the batches and sum them up.
        def sum_batch(per_replica_features):
            """Sums the `PerReplica` values in the `per_replica_features` map."""
            def map_fn(per_replica_values):
                # Sparse components are summed through their `.values`.
                every_component_sparse = all(
                    map(sparse_tensor.is_sparse, per_replica_values.values))
                reduce_one = (
                    (lambda x: math_ops.reduce_sum(x.values))
                    if every_component_sparse else math_ops.reduce_sum)
                per_replica_sums = distribution.experimental_run_v2(
                    reduce_one, (per_replica_values, ))
                return distribution.reduce(reduce_util.ReduceOp.SUM,
                                           per_replica_sums,
                                           axis=None)

            return nest.map_structure(map_fn, per_replica_features)

        def _reduce(state, batch):
            batch_sums = sum_batch(batch)
            return {
                key: total + batch_sums[key]
                for key, total in state.items()
            }

        def sum_for_loop(dataset):
            totals = {"dense": 0., "ragged": 0., "sparse": 0.}
            for batch in dataset:
                totals = _reduce(totals, batch)
            return totals

        def sum_while_loop(iterator, reduce_fn):
            totals = {"dense": 0., "ragged": 0., "sparse": 0.}
            # Keep reducing until the iterator signals exhaustion.
            try:
                while True:
                    totals = reduce_fn(totals, iterator)
            except (StopIteration, errors.OutOfRangeError):
                return totals

        while_loop_sums = sum_while_loop(
            iter(dataset),
            defun(lambda state, iterator: _reduce(state, next(iterator))))
        self.assertDictEqual(while_loop_sums, defun(sum_for_loop)(dataset))

        # When there's no partial batch, the sum is smaller.
        expected_total = (
            200. if input_type == "dataset" and drop_remainder else 310.)
        self.assertAllEqual(nest.flatten(while_loop_sums),
                            [expected_total] * 3)
示例#21
0
def _call_for_each_replica(distribution, device_map, fn, args, kwargs):
  """Run `fn` in separate threads, once per replica/worker device.

  `device_map.num_replicas_in_graph` `_MirroredReplicaThread`s are created.
  The calling thread then drives them in rounds: it sets each thread's
  `should_run` event, waits on its `has_paused` event, and, when the threads
  have paused without being done, runs the first thread's `merge_fn` on the
  regrouped merge arguments and stores each replica's part of the result on
  its thread as `merge_result`.

  Args:
    distribution: the DistributionStrategy object.
    device_map: the DeviceMap with the devices to run `fn` on.
    fn: function to run (will be run once per replica, each in its own thread).
    args: positional arguments for `fn`; replica `i` is given
      `values.select_replica(i, args)`.
    kwargs: keyword arguments for `fn`; replica `i` is given
      `values.select_replica(i, kwargs)`.

  Returns:
    Merged return value of `fn` across all replicas (every thread's
    `main_result`, combined with `values.regroup`), or `None` when
    `coord.should_stop()` becomes true while the threads are being driven.
    NOTE(review): `coord.join` in the `finally` block presumably re-raises an
    exception recorded by the coordinator instead of letting `None` through --
    confirm against the Coordinator documentation.

  Raises:
    RuntimeError: If, within one round, some replica threads report `done`
      while others do not, i.e. the replicas made a different number of
      `get_replica_context().merge_call()` calls.
  """
  # TODO(josh11b): Add this option once we add synchronization to variable
  # creation. Until then, this is pretty unsafe to use.
  # While this stays False only the sequential branch of the loop below runs:
  # each thread is released and waited on before the next one is released.
  run_concurrently = False
  if not context.executing_eagerly():
    # Needed for per-thread device, etc. contexts in graph mode.
    ops.get_default_graph().switch_to_thread_local()

  # `_RequestedStop` is registered as a clean-stop exception type for the
  # coordinator shared by all replica threads.
  coord = coordinator.Coordinator(clean_stop_exception_types=(_RequestedStop,))

  # A single store handed to every replica's variable creator function.
  shared_variable_store = {}

  # TODO(isaprykin): Create these threads once instead of during every call.
  threads = []
  for index in range(device_map.num_replicas_in_graph):
    variable_creator_fn = shared_variable_creator.make_fn(
        shared_variable_store, index)
    # Each thread gets only its own replica's slice of `args`/`kwargs`.
    t = _MirroredReplicaThread(
        distribution, coord, index, device_map, variable_creator_fn, fn,
        values.select_replica(index, args),
        values.select_replica(index, kwargs))
    threads.append(t)

  for t in threads:
    t.start()

  # When `fn` starts `should_run` event is set on _MirroredReplicaThread
  # (`MRT`) threads. The execution waits until
  # `MRT.has_paused` is set, which indicates that either `fn` is
  # complete or a `get_replica_context().merge_call()` is called.  If `fn` is
  # complete, then `MRT.done` is set to True.  Otherwise, arguments
  # of `get_replica_context().merge_call` from all paused threads are grouped
  # and the `merge_fn` is performed.  Results of the
  # `get_replica_context().merge_call` are then set to `MRT.merge_result`.
  # Each such `get_replica_context().merge_call` call returns the
  # `MRT.merge_result` for that thread when `MRT.should_run` event
  # is reset again. Execution of `fn` resumes.

  try:
    with coord.stop_on_exception():
      all_done = False
      # One iteration per round: release the threads, wait for them to pause,
      # then either finish or perform the merge.
      while not all_done and not coord.should_stop():
        # `done[i]` records whether thread i reported `done` this round.
        done = []
        if run_concurrently:
          # Release every thread first, then wait for each in turn.
          for t in threads:
            t.should_run.set()
          for t in threads:
            t.has_paused.wait()
            t.has_paused.clear()
            # Early exit; the `finally` block below still runs.
            if coord.should_stop():
              return None
            done.append(t.done)
        else:
          # Release one thread and wait for it to pause before releasing the
          # next one.
          for t in threads:
            t.should_run.set()
            t.has_paused.wait()
            t.has_paused.clear()
            # Early exit; the `finally` block below still runs.
            if coord.should_stop():
              return None
            done.append(t.done)
        if coord.should_stop():
          return None
        all_done = all(done)
        if not all_done:
          # A mix of done and not-done threads means the replicas disagree on
          # how many merge_call()s they made.
          if any(done):
            raise RuntimeError("Some replicas made a different number of "
                               "replica_context().merge_call() calls.")
          # get_replica_context().merge_call() case
          merge_args = values.regroup(
              device_map, tuple(t.merge_args for t in threads))
          merge_kwargs = values.regroup(
              device_map, tuple(t.merge_kwargs for t in threads))
          # We capture the name_scope of the MRT when we call merge_fn
          # to ensure that if we have opened a name scope in the MRT,
          # it will be respected when executing the merge function. We only
          # capture the name_scope from the first MRT and assume it is
          # the same for all other MRTs.
          mtt_captured_name_scope = threads[0].captured_name_scope
          # Capture and merge the control dependencies from all the threads.
          mtt_captured_control_deps = set()
          for t in threads:
            mtt_captured_control_deps.update(t.captured_control_deps)
          # Only the first thread's merge_fn is invoked, once, with the
          # regrouped arguments of all threads.
          with ops.name_scope(mtt_captured_name_scope),\
              ops.control_dependencies(mtt_captured_control_deps):
            merge_result = threads[0].merge_fn(distribution, *merge_args,
                                               **merge_kwargs)
          # Hand each thread its own replica's part of the result.
          for r, t in enumerate(threads):
            t.merge_result = values.select_replica(r, merge_result)
  finally:
    # Runs on normal completion, early `return None`, and exceptions: set
    # `should_run` on every thread (presumably so none remains blocked waiting
    # for it) before joining them through the coordinator.
    for t in threads:
      t.should_run.set()
    coord.join(threads)

  return values.regroup(device_map, tuple(t.main_result for t in threads))