Example #1
    def distributed_function(x, y, sample_weights, learning_phase=None):
      """A single step of the distributed execution across replicas."""
      del learning_phase

      # TODO(b/129653859):  Simplify after PerReplica can be the input of
      # `def_function.function`.  `regroup` calls and re-wrapping in
      # PerReplica won't be needed then.
      if isinstance(strategy, one_device_strategy.OneDeviceStrategy):
        device_map = values.SingleDeviceMap(devices[0])
        wrap_class = lambda d, x: x
      else:
        device_map = values.ReplicaDeviceMap(devices)
        wrap_class = values.PerReplica

      # Transform each list of lists of values into per-replica objects
      # in the case of mirrored strategy.  For example, for 2 replicas:
      # [[x0, y0], [x1, y1]] -> [PerReplica(d0:x0, d1:x1),
      #                          PerReplica(d0:y0, d1:y1)]
      x = values.regroup(device_map, x, wrap_class)
      y = values.regroup(device_map, y, wrap_class) if y else None
      sample_weights = values.regroup(device_map, sample_weights,
                                      wrap_class) if sample_weights else None

      # Call `Model.{train,test,predict}_on_batch` on every replica passing
      # PerReplicas as arguments.  On every replica inside this call, each
      # PerReplica object will return the value for that replica.  The outputs
      # are PerReplicas too.
      outputs = strategy.experimental_run_v2(
          per_replica_function, args=(x, y, sample_weights))
      # From the PerReplica outputs, reduce or pick the values to return.
      all_outputs = unwrap_outputs(
          strategy, outputs, with_loss_tensor=(mode != ModeKeys.PREDICT))
      return all_outputs
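The comment in Example #1 describes the transposition that `values.regroup` performs on per-replica inputs. Below is a minimal, framework-free sketch of that behaviour; the classes and helper names (`PerReplicaToy`, `regroup_toy`, `select_replica_toy`) are toy stand-ins and not the real `tensorflow.python.distribute.values` API.

# Toy illustration of what `regroup` conceptually does; the real implementation
# also handles nested structures, namedtuples and custom wrapper classes.
class PerReplicaToy(object):
  """Holds one value per replica (stand-in for values.PerReplica)."""

  def __init__(self, replica_values):
    self.values = tuple(replica_values)


def regroup_toy(per_replica_inputs, wrap_class=PerReplicaToy):
  """[[x0, y0], [x1, y1]] -> [wrap(x0, x1), wrap(y0, y1)]."""
  return [wrap_class(v) for v in zip(*per_replica_inputs)]


def select_replica_toy(replica_id, structure):
  """Undo regroup_toy for a flat list of wrapped values."""
  return [wrapped.values[replica_id] for wrapped in structure]


merged = regroup_toy([["x0", "y0"], ["x1", "y1"]])
assert merged[0].values == ("x0", "x1")
assert select_replica_toy(1, merged) == ["x1", "y1"]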
Example #2
    def get_next(self, name=None):
        """Returns the next input from the iterator for all replicas."""
        if not self._enable_get_next_as_optional:
            replicas = []
            for i, worker in enumerate(self._input_workers.worker_devices):
                if name is not None:
                    d = tf_device.DeviceSpec.from_string(worker)
                    new_name = "%s_%s_%d" % (name, d.job, d.task)
                else:
                    new_name = None
                with ops.device(worker):
                    # Make `replicas` a flat list of values across all replicas.
                    replicas.extend(
                        self._iterators[i].get_next_as_list_deprecated(
                            new_name))
            return values.regroup(self._input_workers.device_map, replicas)

        out_of_range_replicas = []

        def out_of_range_fn(worker_index, device):
            """This function will throw an OutOfRange error."""
            # As this is only called when there is no data left, calling
            # get_next() will trigger an OutOfRange error.
            data = self._iterators[worker_index].get_next(device)
            out_of_range_replicas.append(data)
            return data

        global_has_value, replicas = _get_next_as_optional(
            self, self._strategy)
        results = []
        for i, worker in enumerate(self._input_workers.worker_devices):
            with ops.device(worker):
                devices = self._input_workers.compute_devices_for_worker(i)
                for j, device in enumerate(devices):
                    with ops.device(device):
                        # pylint: disable=undefined-loop-variable
                        # pylint: disable=cell-var-from-loop
                        # It is fine for the lambda to capture variables from the loop as
                        # the lambda is executed in the loop as well.
                        result = control_flow_ops.cond(
                            global_has_value, lambda: replicas[i][j],
                            lambda: out_of_range_fn(i, device))
                        # pylint: enable=cell-var-from-loop
                        # pylint: enable=undefined-loop-variable
                        results.append(result)
        replicas = results

        # Some dimensions in `replicas` will become unknown after we conditionally
        # return the real tensors or the dummy tensors. We fix the input shapes by
        # using the shapes from `out_of_range_replicas`, because it calls
        # get_next() internally.
        flattened_replicas = nest.flatten(replicas)
        for i, replica_data in enumerate(nest.flatten(out_of_range_replicas)):
            flattened_replicas[i].set_shape(replica_data.get_shape())
        replicas = nest.pack_sequence_as(replicas, flattened_replicas)

        return values.regroup(self._input_workers.device_map, replicas)
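The shape-fixing step at the end of Example #2 relies on `nest.flatten` and `nest.pack_sequence_as` being exact inverses over the same structure. Here is a short sketch of that round trip using the public `tf.nest` API, assuming TensorFlow 2.x is installed; the toy structure below is made up for illustration.

import tensorflow as tf

# A nested structure mixing dicts, tuples and lists, as an iterator might yield.
structure = {"features": (tf.constant([1.0, 2.0]), tf.constant(3.0)),
             "labels": [tf.constant(0)]}

flat = tf.nest.flatten(structure)  # deterministic, depth-first leaf order
assert len(flat) == 3

# Per-leaf fixes (such as the set_shape calls above) can be applied to the
# flat list and then repacked into the original layout.
rebuilt = tf.nest.pack_sequence_as(structure, flat)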
Example #3
  def get_next(self, name=None):
    """Returns the next input from the iterator for all replicas."""
    if not self._enable_get_next_as_optional:
      replicas = []
      for i, worker in enumerate(self._input_workers.worker_devices):
        if name is not None:
          d = tf_device.DeviceSpec.from_string(worker)
          new_name = "%s_%s_%d" % (name, d.job, d.task)
        else:
          new_name = None
        with ops.device(worker):
          # Make `replicas` a flat list of values across all replicas.
          replicas.extend(
              self._iterators[i].get_next_as_list_deprecated(new_name))
      return values.regroup(self._input_workers.device_map, replicas)

    out_of_range_replicas = []
    def out_of_range_fn(worker_index, device):
      """This function will throw an OutOfRange error."""
      # As this is only called when there is no data left, calling
      # get_next() will trigger an OutOfRange error.
      data = self._iterators[worker_index].get_next(device)
      out_of_range_replicas.append(data)
      return data

    global_has_value, replicas = _get_next_as_optional(self, self._strategy)
    results = []
    for i, worker in enumerate(self._input_workers.worker_devices):
      with ops.device(worker):
        devices = self._input_workers.compute_devices_for_worker(i)
        for j, device in enumerate(devices):
          with ops.device(device):
            # pylint: disable=undefined-loop-variable
            # pylint: disable=cell-var-from-loop
            # It is fine for the lambda to capture variables from the loop as
            # the lambda is executed in the loop as well.
            result = control_flow_ops.cond(global_has_value,
                                           lambda: replicas[i][j],
                                           lambda: out_of_range_fn(i, device))
            # pylint: enable=cell-var-from-loop
            # pylint: enable=undefined-loop-variable
            results.append(result)
    replicas = results

    # Some dimensions in `replicas` will become unknown after we conditionally
    # return the real tensors or the dummy tensors. We fix the input shapes by
    # using the shapes from `out_of_range_replicas`, because it calls
    # get_next() internally.
    flattened_replicas = nest.flatten(replicas)
    for i, replica_data in enumerate(nest.flatten(out_of_range_replicas)):
      flattened_replicas[i].set_shape(replica_data.get_shape())
    replicas = nest.pack_sequence_as(replicas, flattened_replicas)

    return values.regroup(self._input_workers.device_map, replicas)
Example #4
 def testWrapAListOfTwoTuples(self):
   device_map = values.ReplicaDeviceMap((_device_str(0), _device_str(1)))
   result = values.regroup(device_map, [("1", "2"), ("3", "4")])
   self.assertIsInstance(result, tuple)
   self.assertEqual(2, len(result))
   self._is_per_replica(result[0], ("1", "3"), values.PerReplica)
   self._is_per_replica(result[1], ("2", "4"), values.PerReplica)
Example #5
  def testNested(self):
    device_map = values.ReplicaDeviceMap((_device_str(0), _device_str(1)))
    result = values.regroup(device_map,
                            (_nested_value("1"), _nested_value("2")))
    self.assertIsInstance(result, tuple)
    self.assertEqual(3, len(result))
    self._is_per_replica(result[0], ["a1", "a2"])
    self._is_per_replica(result[2], ["h1", "h2"])

    self.assertIsInstance(result[1], list)
    self.assertEqual(3, len(result[1]))
    self._is_per_replica(result[1][0], ["b1", "b2"])
    self._is_per_replica(result[1][2], ["g1", "g2"])

    self.assertIsInstance(result[1][1], dict)
    self.assertEqual(set(["c", "e"]), set(result[1][1].keys()))
    self._is_per_replica(result[1][1]["c"], ["d1", "d2"])
    self._is_per_replica(result[1][1]["e"], ["f1", "f2"])

    # Also test that we can undo the merge using select_replica()
    self.assertEqual(_nested_value("1"),
                     values.select_replica(0, result))
    self.assertEqual(_nested_value("2"),
                     values.select_replica(1, result))
    # select_device_mirrored() should fail due to non-mirrored values
    with self.assertRaises(TypeError):
      values.select_device_mirrored(_device_str(0), result)
    with self.assertRaises(TypeError):
      values.select_device_mirrored(_device_str(1), result)
Example #6
  def testWrapClass(self):
    # Normally a mirrored value would be the same across devices, but
    # for a test it is convenient to be able to tell the values apart.
    device_map = values.ReplicaDeviceMap((_device_str(0), _device_str(1)))
    result = values.regroup(device_map,
                            (_nested_value("1"), _nested_value("2")),
                            values.Mirrored)
    self.assertIsInstance(result, tuple)
    self.assertEqual(3, len(result))
    self._is_per_replica(result[0], ["a1", "a2"], values.Mirrored)
    self._is_per_replica(result[2], ["h1", "h2"], values.Mirrored)

    self.assertIsInstance(result[1], list)
    self.assertEqual(3, len(result[1]))
    self._is_per_replica(result[1][0], ["b1", "b2"], values.Mirrored)
    self._is_per_replica(result[1][2], ["g1", "g2"], values.Mirrored)

    self.assertIsInstance(result[1][1], dict)
    self.assertEqual(set(["c", "e"]), set(result[1][1].keys()))
    self._is_per_replica(result[1][1]["c"], ["d1", "d2"], values.Mirrored)
    self._is_per_replica(result[1][1]["e"], ["f1", "f2"], values.Mirrored)

    # Also test that we can undo the merge using select_replica()
    self.assertEqual(_nested_value("1"),
                     values.select_replica(0, result))
    self.assertEqual(_nested_value("2"),
                     values.select_replica(1, result))
    # Values are marked as mirrored, so select_device_mirrored() is allowed.
    self.assertEqual(_nested_value("1"),
                     values.select_device_mirrored(_device_str(0), result))
    self.assertEqual(_nested_value("2"),
                     values.select_device_mirrored(_device_str(1), result))
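Examples #5 and #6 differ only in the wrap class, which is the point: `PerReplica` holds potentially different values per replica, while `Mirrored` additionally promises that the copies are identical, and only mirrored values may answer "give me the single value for this device". A toy sketch of that distinction follows; the classes and helper are hypothetical, not the real API.

class PerReplicaToy(object):
  def __init__(self, replica_values):
    self.values = tuple(replica_values)


class MirroredToy(PerReplicaToy):
  """Same container, plus the promise that every copy is identical."""


def select_device_mirrored_toy(device_index, wrapped):
  if not isinstance(wrapped, MirroredToy):
    raise TypeError("Expected a mirrored value, got %r" % (wrapped,))
  return wrapped.values[device_index]


mirrored = MirroredToy(["w", "w"])          # identical copy on each device
per_replica = PerReplicaToy(["x0", "x1"])   # genuinely different per replica

assert select_device_mirrored_toy(0, mirrored) == "w"
try:
  select_device_mirrored_toy(0, per_replica)
except TypeError:
  pass  # mirrors the assertRaises(TypeError) checks in testNested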
Example #7
  def testNamedTupleEstimatorSpec(self):
    with context.graph_mode(), ops.Graph().as_default():
      devices = []
      created_estimator_specs = []

      for device_id in range(3):
        spec = model_fn_lib.EstimatorSpec(
            mode=model_fn_lib.ModeKeys.TRAIN,
            loss=constant_op.constant(device_id / 2),
            train_op=array_ops.identity(constant_op.constant(device_id)))
        devices.append(_device_str(device_id))
        created_estimator_specs.append(spec)

      device_map = values.ReplicaDeviceMap(devices)
      merged_estimator_spec = values.regroup(
          device_map, created_estimator_specs)

      self.assertTrue(
          isinstance(merged_estimator_spec, model_fn_lib.EstimatorSpec))
      self.assertEqual(model_fn_lib.ModeKeys.TRAIN, merged_estimator_spec.mode)
      for device_id in range(3):
        d = _device_str(device_id)
        self.assertEqual(created_estimator_specs[device_id].loss,
                         merged_estimator_spec.loss.get(d))
        self.assertEqual(created_estimator_specs[device_id].train_op,
                         merged_estimator_spec.train_op.get(d))
        # Scaffold is populated by `EstimatorSpec.__new__`.
        self.assertEqual(created_estimator_specs[device_id].scaffold,
                         merged_estimator_spec.scaffold.get(d))
        # Also test that we can undo the merge using select_replica()
        self.assertEqual(created_estimator_specs[device_id],
                         values.select_replica(device_id,
                                               merged_estimator_spec))
Example #8
  def reduce(self, reduce_op, per_replica_value, destinations):
    """Reduce `per_replica_value` to `destinations`.

    It runs the reduction operation defined by `reduce_op` and puts the
    result on `destinations`.

    Args:
      reduce_op: An instance of `tf.distribute.ReduceOp` that indicates how
        per_replica_value will be reduced.
      per_replica_value: a PerReplica object or a tensor with device set.
      destinations: the reduction destinations.

    Returns:
      a Mirrored object.

    Raises:
      ValueError: if per_replica_value can't be converted to a PerReplica
        object or if destinations aren't strings, Variables or DistributedValues
    """
    if not isinstance(per_replica_value, value_lib.DistributedValues):
      per_replica_value = _make_tensor_into_per_replica(per_replica_value)

    validate_destinations(destinations)

    # Shortcut if `per_replica_value` only contains one value.
    if self._num_between_graph_workers == 1 and len(
        per_replica_value.values) == 1 and _devices_match(
            per_replica_value, destinations):
      return value_lib.regroup(
          per_replica_value.device_map,
          per_replica_value.values,
          wrap_class=value_lib.Mirrored)

    return self.reduce_implementation(reduce_op, per_replica_value,
                                      destinations)
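Example #8's `reduce` combines the per-replica components according to `reduce_op` and returns the result wrapped as a mirrored value on the destinations. The toy sketch below captures that contract in plain Python; it is not the real implementation and ignores device placement, collectives and IndexedSlices handling.

def reduce_toy(reduce_op, per_replica_values, num_destinations):
  """Reduce a list of per-replica values to one mirrored tuple of copies."""
  if reduce_op == "SUM":
    reduced = sum(per_replica_values)
  elif reduce_op == "MEAN":
    reduced = sum(per_replica_values) / len(per_replica_values)
  else:
    raise ValueError("Unsupported reduce_op: %s" % reduce_op)
  # A mirrored result is the same reduced value, one copy per destination.
  return tuple(reduced for _ in range(num_destinations))


assert reduce_toy("SUM", [1.0, 2.0, 3.0], num_destinations=2) == (6.0, 6.0)
assert reduce_toy("MEAN", [1.0, 3.0], num_destinations=1) == (2.0,)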
Example #9
    def reduce_implementation(self, reduce_op, per_replica_value,
                              destinations):
        all_reduced = self._batch_all_reduce(reduce_op, [per_replica_value])[0]
        device_map, logical_device = get_device_map_from(destinations)
        devices = device_map.logical_to_actual_devices(logical_device)

        if (isinstance(all_reduced, value_lib.Mirrored)
                and all_reduced.device_map is device_map
                and all_reduced.logical_device == logical_device):
            return all_reduced

        # Convert `all_reduced` to a `Mirrored` object, as a simple and uniform
        # way to access the component for a particular device.
        if not isinstance(all_reduced, value_lib.Mirrored):
            all_reduced = value_lib.Mirrored(
                value_lib.SingleDeviceMap(all_reduced.device), [all_reduced])

        index = []
        with ops.control_dependencies(all_reduced.values):
            for d in devices:
                with ops.device(d):
                    if d in all_reduced.devices:
                        index.append(array_ops.identity(all_reduced.get(d)))
                    else:
                        # TODO(josh11b): Once we add support for model parallelism, get the
                        # copy from the corresponding replica instead of the primary.
                        index.append(array_ops.identity(all_reduced.primary))

        return value_lib.regroup(device_map,
                                 index,
                                 wrap_class=value_lib.Mirrored)
Example #10
  def testNested(self):
    device_map = values.ReplicaDeviceMap((_device_str(0), _device_str(1)))
    result = values.regroup(device_map,
                            (_nested_value("1"), _nested_value("2")))
    self.assertIsInstance(result, tuple)
    self.assertEqual(3, len(result))
    self._is_per_replica(result[0], ["a1", "a2"])
    self._is_per_replica(result[2], ["h1", "h2"])

    self.assertIsInstance(result[1], list)
    self.assertEqual(3, len(result[1]))
    self._is_per_replica(result[1][0], ["b1", "b2"])
    self._is_per_replica(result[1][2], ["g1", "g2"])

    self.assertIsInstance(result[1][1], dict)
    self.assertEqual(set(["c", "e"]), set(result[1][1].keys()))
    self._is_per_replica(result[1][1]["c"], ["d1", "d2"])
    self._is_per_replica(result[1][1]["e"], ["f1", "f2"])

    # Also test that we can undo the merge using select_replica()
    self.assertEqual(_nested_value("1"),
                     values.select_replica(0, result))
    self.assertEqual(_nested_value("2"),
                     values.select_replica(1, result))
    # select_device_mirrored() should fail due to non-mirrored values
    with self.assertRaises(TypeError):
      values.select_device_mirrored(_device_str(0), result)
    with self.assertRaises(TypeError):
      values.select_device_mirrored(_device_str(1), result)
Example #11
  def testWrapClass(self):
    # Normally a mirrored value would be the same across devices, but
    # for a test it is convenient to be able to tell the values apart.
    device_map = values.ReplicaDeviceMap((_device_str(0), _device_str(1)))
    result = values.regroup(device_map,
                            (_nested_value("1"), _nested_value("2")),
                            values.Mirrored)
    self.assertIsInstance(result, tuple)
    self.assertEqual(3, len(result))
    self._is_per_replica(result[0], ["a1", "a2"], values.Mirrored)
    self._is_per_replica(result[2], ["h1", "h2"], values.Mirrored)

    self.assertIsInstance(result[1], list)
    self.assertEqual(3, len(result[1]))
    self._is_per_replica(result[1][0], ["b1", "b2"], values.Mirrored)
    self._is_per_replica(result[1][2], ["g1", "g2"], values.Mirrored)

    self.assertIsInstance(result[1][1], dict)
    self.assertEqual(set(["c", "e"]), set(result[1][1].keys()))
    self._is_per_replica(result[1][1]["c"], ["d1", "d2"], values.Mirrored)
    self._is_per_replica(result[1][1]["e"], ["f1", "f2"], values.Mirrored)

    # Also test that we can undo the merge using select_replica()
    self.assertEqual(_nested_value("1"),
                     values.select_replica(0, result))
    self.assertEqual(_nested_value("2"),
                     values.select_replica(1, result))
    # Values are marked as mirrored, so select_device_mirrored() is allowed.
    self.assertEqual(_nested_value("1"),
                     values.select_device_mirrored(_device_str(0), result))
    self.assertEqual(_nested_value("2"),
                     values.select_device_mirrored(_device_str(1), result))
Example #12
def _make_mirrored_indexed_slices(devices, values, indices, dense_shape):
    values = [
        _make_indexed_slices(values, indices, dense_shape, d) for d in devices
    ]
    return value_lib.regroup(value_lib.ReplicaDeviceMap(devices),
                             values,
                             wrap_class=value_lib.Mirrored)
Example #13
  def testNamedTupleEstimatorSpec(self):
    with context.graph_mode(), ops.Graph().as_default():
      devices = []
      created_estimator_specs = []

      for device_id in range(3):
        spec = model_fn_lib.EstimatorSpec(
            mode=model_fn_lib.ModeKeys.TRAIN,
            loss=constant_op.constant(device_id / 2),
            train_op=array_ops.identity(constant_op.constant(device_id)))
        devices.append(_device_str(device_id))
        created_estimator_specs.append(spec)

      device_map = values.ReplicaDeviceMap(devices)
      merged_estimator_spec = values.regroup(
          device_map, created_estimator_specs)

      self.assertTrue(
          isinstance(merged_estimator_spec, model_fn_lib.EstimatorSpec))
      self.assertEqual(model_fn_lib.ModeKeys.TRAIN, merged_estimator_spec.mode)
      for device_id in range(3):
        d = _device_str(device_id)
        self.assertEqual(created_estimator_specs[device_id].loss,
                         merged_estimator_spec.loss.get(d))
        self.assertEqual(created_estimator_specs[device_id].train_op,
                         merged_estimator_spec.train_op.get(d))
        # Scaffold is populated by `EstimatorSpec.__new__`.
        self.assertEqual(created_estimator_specs[device_id].scaffold,
                         merged_estimator_spec.scaffold.get(d))
        # Also test that we can undo the merge using select_replica()
        self.assertEqual(created_estimator_specs[device_id],
                         values.select_replica(device_id,
                                               merged_estimator_spec))
Example #14
 def testMirroredContainer(self):
     if context.num_gpus() < 1 and context.executing_eagerly():
         self.skipTest(
             "A GPU is not available for this test in eager mode.")
     v, device_map, mirrored = _make_mirrored()
     result = values.regroup(device_map, v)
     self.assertIs(mirrored, result)
Example #15
 def _experimental_distribute_values_from_function(self, value_fn):
   per_replica_values = []
   for replica_id in range(self._num_replicas_in_sync):
     per_replica_values.append(value_fn(
         distribute_lib.ValueContext(replica_id,
                                     self._num_replicas_in_sync)))
   return values.regroup(per_replica_values, always_wrap=True)
Example #16
  def testOneDevice(self):
    result = values.regroup({_device_str(0): _nested_value("1")})
    # On one device regroup() and select_device() are basically identity.
    self.assertEqual(_nested_value("1"), result)
    self.assertEqual(_nested_value("1"),
                     values.select_device(_device_str(0), result))

    # The one exception has to do with MirroredVariables.
    d = "/device:CPU:0"
    with ops.device(d):
      v = variable_scope.get_variable(
          name="v", initializer=1., use_resource=True)
      index = {d: v}
    mirrored = values.MirroredVariable(index, v,
                                       variable_scope.VariableAggregation.SUM)
    result = values.regroup(index)
    self.assertIs(mirrored, result)
Example #17
def _tpu_run(strategy, fn, args, kwargs):
    """Common implementation of TPUStrategy.experimental_run_v2."""
    if context.executing_eagerly() and not ops.inside_function():
        raise NotImplementedError(
            "Eager mode not supported in TPUStrategy outside TF functions.")

    if kwargs is None:
        kwargs = {}

    # Used to re-structure flattened output tensors from `tpu.replicate()`
    # into a structured format.
    result = [[]]

    def replicated_fn(replica_id, replica_args, replica_kwargs):
        """Wraps user function to provide replica ID and `Tensor` inputs."""
        with _TPUReplicaContext(strategy, replica_id_in_sync_group=replica_id):
            result[0] = fn(*replica_args, **replica_kwargs)
        return result[0]

    replicate_inputs = []  # By replica.
    for i in range(strategy.num_replicas_in_sync):
        replicate_inputs.append([
            constant_op.constant(i, dtype=dtypes.int32),
            values.select_replica(i, args),
            values.select_replica(i, kwargs)
        ])

    # Construct and pass `maximum_shapes` so that we could support dynamic
    # shapes using dynamic padder.
    if replicate_inputs:
        maximum_shapes = []
        flattened_list = nest.flatten(replicate_inputs[0])
        for input_tensor in flattened_list:
            maximum_shapes.append(input_tensor.get_shape())
        maximum_shapes = nest.pack_sequence_as(replicate_inputs[0],
                                               maximum_shapes)
    else:
        maximum_shapes = None

    with strategy.scope():
        replicate_outputs = tpu.replicate(replicated_fn,
                                          replicate_inputs,
                                          maximum_shapes=maximum_shapes)

    # Remove all no ops that may have been added during 'tpu.replicate()'
    if isinstance(result[0], list):
        result[0] = [
            output for output in result[0] if tensor_util.is_tensor(output)
        ]

    # Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
    replicate_outputs = [
        nest.pack_sequence_as(result[0], nest.flatten(replica_output))
        for replica_output in replicate_outputs
    ]

    device_map = strategy.extended._device_map  # pylint: disable=protected-access
    return values.regroup(device_map, replicate_outputs)
Example #18
  def _experimental_run_steps_on_iterator(self, fn, iterator, iterations,
                                          initial_loop_values=None):
    if initial_loop_values is None:
      initial_loop_values = {}
    initial_loop_values = nest.flatten(initial_loop_values)

    ctx = values.MultiStepContext()
    def body(i, *args):
      """A wrapper around `fn` to create the while loop body."""
      del args
      fn_inputs = iterator.get_next()
      if not isinstance(fn_inputs, tuple):
        fn_inputs = (fn_inputs,)
      fn_result = fn(ctx, fn_inputs)
      for (name, output) in ctx.last_step_outputs.items():
        # Convert all outputs to tensors, potentially from `DistributedValues`.
        ctx.last_step_outputs[name] = self._unwrap(output)
      flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
      with ops.control_dependencies([fn_result]):
        return [i + 1] + flat_last_step_outputs

    # We capture the control_flow_context at this point, before we run `fn`
    # inside a while_loop. This is useful in cases where we might need to exit
    # these contexts and get back to the outer context to do some things,
    # e.g. create an op which should be evaluated only once at the end of the
    # loop on the host. One such usage is in creating metrics' value op.
    self._outer_control_flow_context = (
        ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

    cond = lambda i, *args: i < iterations
    i = constant_op.constant(0)
    loop_result = control_flow_ops.while_loop(
        cond, body, [i] + initial_loop_values, name="",
        parallel_iterations=1, back_prop=False, swap_memory=False,
        return_same_structure=True)
    del self._outer_control_flow_context

    ctx.run_op = control_flow_ops.group(loop_result)

    # Convert the last_step_outputs from a list to the original dict structure
    # of last_step_outputs.
    last_step_tensor_outputs = loop_result[1:]
    last_step_tensor_outputs_dict = nest.pack_sequence_as(
        ctx.last_step_outputs, last_step_tensor_outputs)

    for name, reduce_op in ctx._last_step_outputs_reduce_ops.items():  # pylint: disable=protected-access
      output = last_step_tensor_outputs_dict[name]
      # For outputs that have already been reduced, wrap them in a Mirrored
      # container, else in a PerReplica container.
      if reduce_op is None:
        last_step_tensor_outputs_dict[name] = values.regroup(
            {d: t for d, t in zip(self._devices, output)}, values.PerReplica)
      else:
        assert len(output) == 1
        last_step_tensor_outputs_dict[name] = output[0]

    ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access
    return ctx
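The while loop in Example #18 packs the step counter together with the flattened step outputs into a single loop state. Below is a minimal sketch of that pattern using the public `tf.while_loop` API, assuming TensorFlow 2.x; the accumulator is made up for illustration.

import tensorflow as tf

# Loop state mirrors `[i] + flat_last_step_outputs`: a counter plus flat values.
def cond(i, total):
  return i < 3


def body(i, total):
  return [i + 1, total + tf.cast(i, tf.float32)]


i, total = tf.while_loop(cond, body, [tf.constant(0), tf.constant(0.0)])
# total == 0 + 1 + 2 == 3.0; the flat outputs are then repacked with
# nest.pack_sequence_as, just as the example does for last_step_outputs.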
Example #19
    def tpu_function(args, kwargs):
      """TF Function used to replicate the user computation."""
      if kwargs is None:
        kwargs = {}

      # Used to re-structure flattened output tensors from `tpu.replicate()`
      # into a structured format.
      result = [[]]

      def replicated_fn(replica_id, replica_args, replica_kwargs):
        """Wraps user function to provide replica ID and `Tensor` inputs."""
        with _TPUReplicaContext(strategy, replica_id_in_sync_group=replica_id):
          result[0] = fn(*replica_args, **replica_kwargs)
        return result[0]

      replicate_inputs = []  # By replica.
      for i in range(strategy.num_replicas_in_sync):
        replicate_inputs.append(
            [constant_op.constant(i, dtype=dtypes.int32),
             values.select_replica(i, args),
             values.select_replica(i, kwargs)])

      # Construct and pass `maximum_shapes` so that we could support dynamic
      # shapes using dynamic padder.
      if replicate_inputs:
        maximum_shapes = []
        flattened_list = nest.flatten(replicate_inputs[0])
        for input_tensor in flattened_list:
          if tensor_util.is_tensor(input_tensor):
            maximum_shape = input_tensor.get_shape()
          else:
            maximum_shape = tensor_shape.TensorShape(np.shape(input_tensor))
          maximum_shapes.append(maximum_shape)
        maximum_shapes = nest.pack_sequence_as(replicate_inputs[0],
                                               maximum_shapes)
      else:
        maximum_shapes = None

      with strategy.scope():
        replicate_outputs = tpu.replicate(replicated_fn, replicate_inputs,
                                          maximum_shapes=maximum_shapes)

      # Remove all no ops that may have been added during 'tpu.replicate()'
      if isinstance(result[0], list):
        result[0] = [
            output for output in result[0] if tensor_util.is_tensor(output)
        ]

      # Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
      if result[0] is None:
        replicate_outputs = [None] * len(replicate_outputs)
      else:
        replicate_outputs = [
            nest.pack_sequence_as(result[0], nest.flatten(replica_output))
            for replica_output in replicate_outputs
        ]
      device_map = self._device_map  # pylint: disable=protected-access
      return values.regroup(device_map, replicate_outputs)
Example #20
  def testOneDevice(self):
    device_map = values.ReplicaDeviceMap((_device_str(0),))
    result = values.regroup(device_map, (_nested_value("1"),))
    # On one device regroup() and select_replica() are basically identity.
    self.assertEqual(_nested_value("1"), result)
    self.assertEqual(_nested_value("1"),
                     values.select_replica(0, result))

    # The one exception has to do with MirroredVariables.
    d = "/device:CPU:0"
    with ops.device(d):
      v = variable_scope.get_variable(
          name="v", initializer=1., use_resource=True)
      device_map = values.ReplicaDeviceMap((d,))
    mirrored = values.MirroredVariable(None, device_map, (v,),
                                       variable_scope.VariableAggregation.SUM)
    result = values.regroup(device_map, (v,))
    self.assertIs(mirrored, result)
Example #21
  def testOneDevice(self):
    device_map = values.ReplicaDeviceMap((_device_str(0),))
    result = values.regroup(device_map, (_nested_value("1"),))
    # On one device regroup() and select_replica() are basically identity.
    self.assertEqual(_nested_value("1"), result)
    self.assertEqual(_nested_value("1"),
                     values.select_replica(0, result))

    # The one exception has to do with MirroredVariables.
    d = "/device:CPU:0"
    with ops.device(d):
      v = variable_scope.get_variable(
          name="v", initializer=1., use_resource=True)
      device_map = values.ReplicaDeviceMap((d,))
    mirrored = values.MirroredVariable(None, device_map, (v,),
                                       variable_scope.VariableAggregation.SUM)
    result = values.regroup(device_map, (v,))
    self.assertIs(mirrored, result)
Example #22
def _tpu_run(strategy, fn, args, kwargs):
  """Common implementation of TPUStrategy.experimental_run_v2."""
  if context.executing_eagerly() and not ops.inside_function():
    raise NotImplementedError(
        "Eager mode not supported in TPUStrategy outside TF functions.")

  if kwargs is None:
    kwargs = {}

  # Used to re-structure flattened output tensors from `tpu.replicate()`
  # into a structured format.
  result = [[]]

  def replicated_fn(replica_id, replica_args, replica_kwargs):
    """Wraps user function to provide replica ID and `Tensor` inputs."""
    with _TPUReplicaContext(strategy, replica_id_in_sync_group=replica_id):
      result[0] = fn(*replica_args, **replica_kwargs)
    return result[0]

  replicate_inputs = []  # By replica.
  for i in range(strategy.num_replicas_in_sync):
    replicate_inputs.append(
        [constant_op.constant(i, dtype=dtypes.int32),
         values.select_replica(i, args),
         values.select_replica(i, kwargs)])

  # Construct and pass `maximum_shapes` so that we could support dynamic
  # shapes using dynamic padder.
  if replicate_inputs:
    maximum_shapes = []
    flattened_list = nest.flatten(replicate_inputs[0])
    for input_tensor in flattened_list:
      maximum_shapes.append(input_tensor.get_shape())
    maximum_shapes = nest.pack_sequence_as(replicate_inputs[0],
                                           maximum_shapes)
  else:
    maximum_shapes = None

  with strategy.scope():
    replicate_outputs = tpu.replicate(replicated_fn, replicate_inputs,
                                      maximum_shapes=maximum_shapes)

  # Remove all no ops that may have been added during 'tpu.replicate()'
  if isinstance(result[0], list):
    result[0] = [
        output for output in result[0] if tensor_util.is_tensor(output)
    ]

  # Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
  replicate_outputs = [
      nest.pack_sequence_as(result[0], nest.flatten(replica_output))
      for replica_output in replicate_outputs
  ]

  device_map = strategy.extended._device_map  # pylint: disable=protected-access
  return values.regroup(device_map, replicate_outputs)
Example #23
def per_device_dataset(batch, devices):
    index = {}

    def get_ith(i_):
        return lambda x: x[i_]

    for i, d in enumerate(devices):
        index[d] = nest.map_structure(get_ith(i), batch)

    return values.regroup(index)
Example #24
def _fake_mirrored(value, devices):
    """Create a faked Mirrored object for testing.

    All components of the returned Mirrored share the same object, which is
    not true in reality.
    """
    devices = _get_devices(devices)
    values = []
    for d in devices:
        with ops.device(d):
            values.append(array_ops.identity(value))
    return value_lib.regroup(values, wrap_class=value_lib.Mirrored)
Example #25
def per_device_dataset(iterator, devices):
    batch = iterator.get_next()
    print(batch)
    index = {}

    def get_ith(i):
        return lambda x: x[i]

    for i, d in enumerate(devices):
        index[d] = nest.map_structure(get_ith(i), batch)

    return values.regroup(index)
Example #26
def per_device_dataset(iterator, devices):
    """ Split a batch features into per-device features """
    batch = iterator.get_next()
    index = {}

    def get_ith(i_):
        return lambda x: x[i_]

    for i, d in enumerate(devices):
        index[d] = nest.map_structure(get_ith(i), batch)

    return values.regroup(index)
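The `get_ith` helper in Examples #23, #25 and #26 exists because Python lambdas capture variables, not values: a bare `lambda x: x[i]` created inside the loop would see the final value of `i` when it is eventually called. A short illustration of the pitfall and the fix:

broken = []
for i in range(3):
  broken.append(lambda x: x[i])        # every lambda sees the final i

fixed = []
for i in range(3):
  fixed.append((lambda i_: lambda x: x[i_])(i))  # bind i at definition time

batch = ["a", "b", "c"]
assert [f(batch) for f in broken] == ["c", "c", "c"]
assert [f(batch) for f in fixed] == ["a", "b", "c"]

`get_ith(i)` is simply the named, more readable form of the inner factory used to build `fixed`.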
Example #27
  def testNamedTuple(self):

    # We include toy implementations of Scaffold and EstimatorSpec to
    # avoid a dependency on Estimator here.

    class Scaffold(object):
      pass

    class EstimatorSpec(collections.namedtuple(
        "EstimatorSpec", ["mode", "loss", "train_op", "scaffold"])):

      def __new__(cls, mode, loss, train_op, scaffold=None):
        return super(EstimatorSpec, cls).__new__(
            cls, mode=mode, loss=loss, train_op=train_op,
            scaffold=scaffold or Scaffold())

    with context.graph_mode(), ops.Graph().as_default():
      devices = []
      created_estimator_specs = []

      for device_id in range(3):
        spec = EstimatorSpec(
            mode=mode_keys.EstimatorModeKeys.TRAIN,
            loss=constant_op.constant(device_id / 2),
            train_op=array_ops.identity(constant_op.constant(device_id)))
        devices.append(_device_str(device_id))
        created_estimator_specs.append(spec)

      device_map = values.ReplicaDeviceMap(devices)
      merged_estimator_spec = values.regroup(
          device_map, created_estimator_specs)

      self.assertIsInstance(merged_estimator_spec, EstimatorSpec)
      self.assertEqual(mode_keys.EstimatorModeKeys.TRAIN,
                       merged_estimator_spec.mode)
      for device_id in range(3):
        d = _device_str(device_id)
        self.assertEqual(created_estimator_specs[device_id].loss,
                         merged_estimator_spec.loss.get(d))
        self.assertEqual(created_estimator_specs[device_id].train_op,
                         merged_estimator_spec.train_op.get(d))
        # Scaffold is populated by `EstimatorSpec.__new__`.
        self.assertEqual(created_estimator_specs[device_id].scaffold,
                         merged_estimator_spec.scaffold.get(d))
        self.assertIsInstance(created_estimator_specs[device_id].scaffold,
                              Scaffold)
        # Also test that we can undo the merge using select_replica()
        self.assertEqual(created_estimator_specs[device_id],
                         values.select_replica(device_id,
                                               merged_estimator_spec))
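Example #27 shows that `regroup` preserves namedtuple types: the merged result is still an `EstimatorSpec`, rebuilt field by field. In the real implementation a field whose value is the same object on every replica is passed through unwrapped, as `testSameId` later in this listing demonstrates, which is why `merged_estimator_spec.mode` still compares equal to the original TRAIN mode. The sketch below wraps every field for simplicity and uses made-up names.

import collections

Spec = collections.namedtuple("Spec", ["mode", "loss"])


def regroup_namedtuple_toy(specs):
  """Rebuild the same namedtuple type with per-replica tuples as fields."""
  cls = type(specs[0])
  return cls(*[tuple(field_values) for field_values in zip(*specs)])


merged = regroup_namedtuple_toy([Spec("TRAIN", 0.0), Spec("TRAIN", 0.5)])
assert isinstance(merged, Spec)
assert merged.loss == (0.0, 0.5)
assert merged.mode == ("TRAIN", "TRAIN")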
Example #28
 def loop_body(has_data, data, state):
   """Executes `reduce_fn` in a loop till the dataset is empty."""
   # data is list of lists here. where each list corresponds to one worker.
   # TODO(b/130570614): Add support for the multiworker and TPU pods use
   # case.
   if self._input_workers.num_workers == 1:
     data = data[0]
   else:
     raise ValueError("Dataset iteration within a tf.function is"
                      " not supported for multiple workers.")
   per_replica_data = values.regroup(self._input_workers.device_map, data)
   state = reduce_fn(state, per_replica_data)
   has_data, data = _get_next_as_optional(iterator, self._strategy)
   return has_data, data, state
Example #29
 def loop_body(has_data, data, state):
   """Executes `reduce_fn` in a loop till the dataset is empty."""
   del has_data  # Unused.
   # data is list of lists here. where each list corresponds to one worker.
   # TODO(b/130570614): Add support for the multiworker and TPU pods use
   # case.
   if self._input_workers.num_workers == 1:
     data = data[0]
   else:
     raise ValueError("Dataset iteration within a tf.function is"
                      " not supported for multiple workers.")
   state = reduce_fn(state, values.regroup(data))
   has_data, data = _get_next_as_optional(iterator, self._strategy)
   return has_data, data, state
Example #30
  def get_next(self, name=None):
    """Returns the next input from the iterator for all replicas."""
    replicas = []
    for i, worker in enumerate(self._input_workers.worker_devices):
      if name is not None:
        d = tf_device.DeviceSpec.from_string(worker)
        new_name = "%s_%s_%d" % (name, d.job, d.task)
      else:
        new_name = None
      with ops.device(worker):
        # Make `replicas` a flat list of values across all replicas.
        replicas.extend(self._iterators[i].get_next_as_list(new_name))

    return values.regroup(self._input_workers.device_map, replicas)
Example #31
def simple_broadcast(value, destinations, always_mirrored=False):
  """Broadcast `value` to `destinations` using simple copies."""
  device_map, logical_device = get_device_map_from(destinations)
  devices = device_map.logical_to_actual_devices(logical_device)
  if len(devices) == 1 and not always_mirrored:
    return cross_device_utils.copy_tensor_or_indexed_slices_to_device(
        value, devices[0])
  else:
    value_updates = []
    for d in devices:
      value_updates.append(
          cross_device_utils.copy_tensor_or_indexed_slices_to_device(
              value, d))
    return value_lib.regroup(
        device_map, value_updates, wrap_class=value_lib.Mirrored)
Example #32
    def get_next(self, name=None):
        """Scatter the input across hosts and devices."""
        replicas = []
        for worker, iterator in zip(self._input_workers.worker_devices,
                                    self._iterators):
            if name is not None:
                d = tf_device.DeviceSpec.from_string(worker)
                new_name = "%s_%s_%d" % (name, d.job, d.task)
            else:
                new_name = None
            with ops.device(worker):
                data_per_worker = iterator.get_next_as_list(name=new_name)
                # Append to replicas to get a flat list of values indexed by replica.
                replicas.extend(data_per_worker)

        return values.regroup(self._input_workers.device_map, replicas)
Example #33
  def get_next(self, name=None):
    """Scatter the input across hosts and devices."""
    replicas = []
    for worker, iterator in zip(self._input_workers.worker_devices,
                                self._iterators):
      if name is not None:
        d = tf_device.DeviceSpec.from_string(worker)
        new_name = "%s_%s_%d" % (name, d.job, d.task)
      else:
        new_name = None
      with ops.device(worker):
        data_per_worker = iterator.get_next_as_list(name=new_name)
        # Append to replicas to get a flat list of values indexed by replica.
        replicas.extend(data_per_worker)

    return values.regroup(self._input_workers.device_map, replicas)
Example #34
def _make_per_replica(values, devices, regroup=False):
  devices = _get_devices(devices)
  assert len(values) == len(devices)

  # We simulate the result of regroup called on PerReplica which strips the
  # PerReplica wrapper if it has only one value.
  if len(values) == 1 and regroup:
    with ops.device(devices[0]):
      placed_v = array_ops.identity(values[0])
    return placed_v

  index = []
  for d, v in zip(devices, values):
    with ops.device(d):
      placed_v = array_ops.identity(v)
    index.append(placed_v)
  return value_lib.regroup(index)
Example #35
    def experimental_run_v2(self, fn, args=(), kwargs=None):
        """See base class."""
        if context.executing_eagerly() and not ops.inside_function():
            raise NotImplementedError(
                "Eager mode not supported in TPUStrategy outside TF functions."
            )

        if kwargs is None:
            kwargs = {}

        # Used to re-structure flattened output tensors from `tpu.replicate()`
        # into a structured format.
        result = [[]]

        def replicated_fn(replica_id, replica_args, replica_kwargs):
            """Wraps user function to provide replica ID and `Tensor` inputs."""
            with _TPUReplicaContext(self, replica_id_in_sync_group=replica_id):
                result[0] = fn(*replica_args, **replica_kwargs)
            return result[0]

        replicate_inputs = []  # By replica.
        for i in range(self.num_replicas_in_sync):
            replicate_inputs.append([
                constant_op.constant(i, dtype=dtypes.int32),
                values.select_replica(i, args),
                values.select_replica(i, kwargs)
            ])

        with self.scope():
            replicate_outputs = tpu.replicate(replicated_fn, replicate_inputs)

        # Remove all no ops that may have been added during 'tpu.replicate()'
        if isinstance(result[0], list):
            result[0] = [
                output for output in result[0] if tensor_util.is_tensor(output)
            ]

        # Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
        replicate_outputs = [
            nest.pack_sequence_as(result[0], nest.flatten(replica_output))
            for replica_output in replicate_outputs
        ]

        device_map = self.extended._device_map  # pylint: disable=protected-access
        return values.regroup(device_map, replicate_outputs)
Example #36
    def experimental_run(self, fn, input_iterator=None):
        """See base class."""
        if context.executing_eagerly():
            raise NotImplementedError(
                "Eager mode not supported in TPUStrategy.")

        if self.extended._disable_training_loop_on_host:  # pylint: disable=protected-access
            raise NotImplementedError(
                "`experimental_run` is not compatible with "
                "`_disable_training_loop_on_host=True`")

        if input_iterator is None:
            inputs = []
        else:
            inputs = input_iterator.get_next()

        result = [None]

        def replicated_fn(replica_id, inputs):
            """Wraps user function to provide replica ID and `Tensor` inputs."""
            with _TPUReplicaContext(self, replica_id_in_sync_group=replica_id):
                if input_iterator is None:
                    result[0] = fn()
                else:
                    result[0] = fn(inputs)
            return result[0]

        replicate_inputs = []  # By replica.
        for i in range(self.num_replicas_in_sync):
            replicate_inputs.append([
                constant_op.constant(i, dtype=dtypes.int32),
                values.select_replica(i, inputs)
            ])

        with self.scope():
            replicate_outputs = tpu.replicate(replicated_fn, replicate_inputs)

        # Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
        replicate_outputs = [
            nest.pack_sequence_as(result[0], nest.flatten(replica_outputs))
            for replica_outputs in replicate_outputs
        ]

        device_map = self.extended._device_map  # pylint: disable=protected-access
        return values.regroup(device_map, replicate_outputs)
Example #37
  def experimental_run_v2(self, fn, args=(), kwargs=None):
    """See base class."""
    if context.executing_eagerly() and not ops.inside_function():
      raise NotImplementedError(
          "Eager mode not supported in TPUStrategy outside TF functions.")

    if kwargs is None:
      kwargs = {}

    # Used to re-structure flattened output tensors from `tpu.replicate()`
    # into a structured format.
    result = [[]]

    def replicated_fn(replica_id, replica_args, replica_kwargs):
      """Wraps user function to provide replica ID and `Tensor` inputs."""
      with _TPUReplicaContext(self, replica_id_in_sync_group=replica_id):
        result[0] = fn(*replica_args, **replica_kwargs)
      return result[0]

    replicate_inputs = []  # By replica.
    for i in range(self.num_replicas_in_sync):
      replicate_inputs.append(
          [constant_op.constant(i, dtype=dtypes.int32),
           values.select_replica(i, args),
           values.select_replica(i, kwargs)])

    with self.scope():
      replicate_outputs = tpu.replicate(replicated_fn, replicate_inputs)

    # Remove all no ops that may have been added during 'tpu.replicate()'
    if isinstance(result[0], list):
      result[0] = [
          output for output in result[0] if tensor_util.is_tensor(output)
      ]

    # Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
    replicate_outputs = [
        nest.pack_sequence_as(result[0], nest.flatten(replica_output))
        for replica_output in replicate_outputs
    ]

    device_map = self.extended._device_map  # pylint: disable=protected-access
    return values.regroup(device_map, replicate_outputs)
Example #38
  def testSameId(self):
    foo = object()
    device_map = values.ReplicaDeviceMap((_device_str(0), _device_str(1)))
    result = values.regroup(device_map, (("a", foo), ("b", foo)))
    self.assertIsInstance(result, tuple)
    self.assertEqual(2, len(result))
    self._is_per_replica(result[0], ["a", "b"])
    self.assertIs(foo, result[1])

    # Test select_replica(), should undo the merge done by regroup().
    result_0 = values.select_replica(0, result)
    self.assertIsInstance(result_0, tuple)
    self.assertEqual(2, len(result_0))
    self.assertEqual("a", result_0[0])
    self.assertIs(foo, result_0[1])
    result_1 = values.select_replica(1, result)
    self.assertIsInstance(result_1, tuple)
    self.assertEqual(2, len(result_1))
    self.assertEqual("b", result_1[0])
    self.assertIs(foo, result_1[1])
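Example #38 demonstrates two related behaviours: when every replica supplies the same object, `regroup` returns that object untouched instead of wrapping it, and `select_replica` passes such unwrapped leaves straight through. A toy version of the same-object shortcut follows; the class and helper names are hypothetical.

class PerReplicaToy(object):
  def __init__(self, replica_values):
    self.values = tuple(replica_values)


def regroup_leaf_toy(replica_values):
  """Only wrap a leaf when the per-replica values are distinct objects."""
  first = replica_values[0]
  if all(v is first for v in replica_values[1:]):
    return first                 # same object everywhere: no wrapper needed
  return PerReplicaToy(replica_values)


shared = object()
assert regroup_leaf_toy([shared, shared]) is shared
assert isinstance(regroup_leaf_toy(["a", "b"]), PerReplicaToy)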
Example #39
  def testSameId(self):
    foo = object()
    device_map = values.ReplicaDeviceMap((_device_str(0), _device_str(1)))
    result = values.regroup(device_map, (("a", foo), ("b", foo)))
    self.assertIsInstance(result, tuple)
    self.assertEqual(2, len(result))
    self._is_per_replica(result[0], ["a", "b"])
    self.assertIs(foo, result[1])

    # Test select_replica(), should undo the merge done by regroup().
    result_0 = values.select_replica(0, result)
    self.assertIsInstance(result_0, tuple)
    self.assertEqual(2, len(result_0))
    self.assertEqual("a", result_0[0])
    self.assertIs(foo, result_0[1])
    result_1 = values.select_replica(1, result)
    self.assertIsInstance(result_1, tuple)
    self.assertEqual(2, len(result_1))
    self.assertEqual("b", result_1[0])
    self.assertIs(foo, result_1[1])
Example #40
  def batch_reduce(self, reduce_op, value_destination_pairs):
    """Reduce PerReplica objects in a batch.

    Reduce the first element of each pair in `value_destination_pairs` to its
    second element, which indicates the destinations.

    Args:
      reduce_op: An instance of `tf.distribute.ReduceOp` that indicates how
        the `per_replica_value` will be reduced.
      value_destination_pairs: a list or a tuple of PerReplica objects
        (or tensors with device set if there is one device) and destinations.

    Returns:
      a list of Mirrored objects.

    Raises:
      ValueError: if `value_destination_pairs` is not an iterable of
        tuples of PerReplica objects and destinations.
    """
    # TODO(yuefengz): if destinations are different, split into several
    # `_batch_reduce` invocations.
    if not _validate_value_destination_pairs(value_destination_pairs):
      # If the first element of each pair is a tensor, we try to turn it into a
      # PerReplica object.
      value_destination_pairs = _normalize_value_destination_pairs(
          value_destination_pairs)

    for _, d in value_destination_pairs:
      validate_destinations(d)

    # Shortcut if all PerReplica objects contain only one value.
    if self._num_between_graph_workers == 1 and _all_devices_match(
        value_destination_pairs) and len(
            value_destination_pairs[0][0].values) == 1:
      return [
          value_lib.regroup(
              v.device_map, v.values, wrap_class=value_lib.Mirrored)
          for v, _ in value_destination_pairs
      ]

    return self.batch_reduce_implementation(reduce_op, value_destination_pairs)
Example #41
  def experimental_run(self, fn, input_iterator=None):
    """See base class."""
    if context.executing_eagerly():
      raise NotImplementedError("Eager mode not supported in TPUStrategy.")

    if self.extended._disable_training_loop_on_host:  # pylint: disable=protected-access
      raise NotImplementedError(
          "`experimental_run` is not compatible with "
          "`_disable_training_loop_on_host=True`")

    if input_iterator is None:
      inputs = []
    else:
      inputs = input_iterator.get_next()

    result = [None]
    def replicated_fn(replica_id, inputs):
      """Wraps user function to provide replica ID and `Tensor` inputs."""
      with _TPUReplicaContext(self, replica_id_in_sync_group=replica_id):
        if input_iterator is None:
          result[0] = fn()
        else:
          result[0] = fn(inputs)
      return result[0]

    replicate_inputs = []  # By replica.
    for i in range(self.num_replicas_in_sync):
      replicate_inputs.append(
          [constant_op.constant(i, dtype=dtypes.int32),
           values.select_replica(i, inputs)])

    with self.scope():
      replicate_outputs = tpu.replicate(replicated_fn, replicate_inputs)

    # Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
    replicate_outputs = [
        nest.pack_sequence_as(result[0], nest.flatten(replica_outputs))
        for replica_outputs in replicate_outputs]

    device_map = self.extended._device_map  # pylint: disable=protected-access
    return values.regroup(device_map, replicate_outputs)
Example #42
  def get_next(self, name=None):
    """Returns the next input from the iterator for all replicas."""
    if not self._enable_get_next_as_optional:
      replicas = []
      for i, worker in enumerate(self._input_workers.worker_devices):
        if name is not None:
          d = tf_device.DeviceSpec.from_string(worker)
          new_name = "%s_%s_%d" % (name, d.job, d.task)
        else:
          new_name = None
        with ops.device(worker):
          # Make `replicas` a flat list of values across all replicas.
          replicas.extend(
              self._iterators[i].get_next_as_list_deprecated(new_name))
      return values.regroup(self._input_workers.device_map, replicas)

    replicas = []
    worker_has_values = []
    for i, worker in enumerate(self._input_workers.worker_devices):
      if name is not None:
        d = tf_device.DeviceSpec.from_string(worker)
        new_name = "%s_%s_%d" % (name, d.job, d.task)
      else:
        new_name = None
      with ops.device(worker):
        worker_has_value, next_element = (
            self._iterators[i].get_next_as_list(new_name))
        worker_has_values.append(worker_has_value)
        # Make `replicas` a flat list of values across all replicas.
        replicas.append(next_element)

    out_of_range_replicas = []

    def out_of_range_fn(worker_index, device):
      """This function will throw an OutOfRange error."""
      # As this is only called when there is no data left, calling
      # get_next() will trigger an OutOfRange error.
      data = self._iterators[worker_index].get_next(device)
      out_of_range_replicas.append(data)
      return data

    # `global_has_value` indicates whether there is data in this global batch.
    # We do an all-reduce across all the workers in the multi-worker case.
    # TODO(b/126259107): Do strategy.reduce for CollectiveAllReduceStrategy.
    if len(worker_has_values) > 1:
      with ops.device(self._input_workers.compute_devices_for_worker(0)[0]):
        # Place the tf.reduce_any op on device 0 to minimize communication
        # cost.
        # TODO(b/128545270): Investigate why placing it on worker 0 will cause
        # the entire data to copy back from device to host.
        global_has_value = math_ops.reduce_any(worker_has_values)
    else:
      global_has_value = worker_has_values[0]

    results = []
    for i, worker in enumerate(self._input_workers.worker_devices):
      with ops.device(worker):
        devices = self._input_workers.compute_devices_for_worker(i)
        for j, device in enumerate(devices):
          with ops.device(device):
            # pylint: disable=undefined-loop-variable
            # pylint: disable=cell-var-from-loop
            # It is fine for the lambda to capture variables from the loop as
            # the lambda is executed in the loop as well.
            result = control_flow_ops.cond(global_has_value,
                                           lambda: replicas[i][j],
                                           lambda: out_of_range_fn(i, device))
            # pylint: enable=cell-var-from-loop
            # pylint: enable=undefined-loop-variable
            results.append(result)
    replicas = results

    # Some dimensions in `replicas` will become unknown after we conditionally
    # return the real tensors or the dummy tensors. We fix the input shapes by
    # using the shapes from `out_of_range_replicas`, because it calls
    # get_next() internally.
    flattened_replicas = nest.flatten(replicas)
    for i, replica_data in enumerate(nest.flatten(out_of_range_replicas)):
      flattened_replicas[i].set_shape(replica_data.get_shape())
    replicas = nest.pack_sequence_as(replicas, flattened_replicas)

    return values.regroup(self._input_workers.device_map, replicas)
Example #43
 def get_next(self, name=None):
   assert self._input_workers.num_workers == 1
   data_list = self.get_next_as_list(name)
   return values.regroup(self._input_workers.device_map, data_list)
Example #44
 def testMirroredContainer(self):
   if context.num_gpus() < 1 and context.executing_eagerly():
     self.skipTest("A GPU is not available for this test in eager mode.")
   v, device_map, mirrored = _make_mirrored()
   result = values.regroup(device_map, v)
   self.assertIs(mirrored, result)
Example #45
def _call_for_each_replica(distribution, fn, args, kwargs):
  """Run `fn` in separate threads, once per replica/worker device.

  Args:
    distribution: the DistributionStrategy object.
    fn: function to run (will be run once per device, each in its own thread).
    args: positional arguments for `fn`
    kwargs: keyword arguments for `fn`.

  Returns:
    Merged return value of `fn` across all replicas.

  Raises:
    RuntimeError: If fn() calls get_replica_context().merge_call() a different
        number of times on different replica devices.
  """
  # TODO(josh11b): Add this option once we add synchronization to variable
  # creation. Until then, this is pretty unsafe to use.
  run_concurrently = False
  if not context.executing_eagerly():
    # Needed for per-thread device, etc. contexts in graph mode.
    ops.get_default_graph().switch_to_thread_local()

  coord = coordinator.Coordinator(clean_stop_exception_types=(_RequestedStop,))

  shared_variable_store = {}

  # TODO(isaprykin): Create these threads once instead of during every run()
  # call.
  threads = []
  for index, d in enumerate(distribution.extended.worker_devices):
    variable_creator_fn = shared_variable_creator.make_fn(
        shared_variable_store, index)
    t = MirroredExtended._MirroredReplicaThread(  # pylint: disable=protected-access
        distribution, coord, d, variable_creator_fn, fn,
        *values.select_device(d, args), **values.select_device(d, kwargs))
    threads.append(t)

  for t in threads:
    t.start()

  # When `fn` starts, the `should_run` event is set on the
  # _MirroredReplicaThread (`MRT`) threads. The execution waits until
  # `MRT.has_paused` is set, which indicates that either `fn` is
  # complete or a `get_replica_context().merge_call()` is called.  If `fn` is
  # complete, then `MRT.done` is set to True.  Otherwise, arguments
  # of `get_replica_context().merge_call` from all paused threads are grouped
  # and the `merge_fn` is performed.  Results of the
  # `get_replica_context().merge_call` are then set to `MRT.merge_result`.
  # Each such `get_replica_context().merge_call` call returns the
  # `MRT.merge_result` for that thread when `MRT.should_run` event
  # is reset again. Execution of `fn` resumes.

  try:
    with coord.stop_on_exception():
      all_done = False
      while not all_done and not coord.should_stop():
        done = []
        if run_concurrently:
          for t in threads:
            t.should_run.set()
          for t in threads:
            t.has_paused.wait()
            t.has_paused.clear()
            if coord.should_stop():
              return None
            done.append(t.done)
        else:
          for t in threads:
            t.should_run.set()
            t.has_paused.wait()
            t.has_paused.clear()
            if coord.should_stop():
              return None
            done.append(t.done)
        if coord.should_stop():
          return None
        all_done = all(done)
        if not all_done:
          if any(done):
            raise RuntimeError("Some replicas made a different number of "
                               "replica_context().merge_call() calls.")
          # get_replica_context().merge_call() case
          merge_args = values.regroup({t.device: t.merge_args for t in threads})
          merge_kwargs = values.regroup(
              {t.device: t.merge_kwargs for t in threads})
          # We capture the name_scope of the MRT when we call merge_fn
          # to ensure that if we have opened a name scope in the MRT,
          # it will be respected when executing the merge function. We only
          # capture the name_scope from the first MRT and assume it is
          # the same for all other MRTs.
          mtt_captured_name_scope = threads[0].captured_name_scope
          with ops.name_scope(mtt_captured_name_scope):
            merge_result = threads[0].merge_fn(distribution, *merge_args,
                                               **merge_kwargs)
          for t in threads:
            t.merge_result = values.select_device(t.device, merge_result)
  finally:
    for t in threads:
      t.should_run.set()
    coord.join(threads)

  return values.regroup({t.device: t.main_result for t in threads})
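Example #45's controller and `_MirroredReplicaThread` workers synchronize through a pair of events per thread: the controller sets `should_run` to release one step and then waits on `has_paused`, which the worker sets when `fn` either finishes or reaches a `merge_call`. Below is a stripped-down sketch of that handshake with plain `threading`; the class is a toy with a single worker and no merge_call handling.

import threading

class WorkerToy(threading.Thread):
  """Minimal stand-in for _MirroredReplicaThread's event handshake."""

  def __init__(self):
    super(WorkerToy, self).__init__()
    self.should_run = threading.Event()
    self.has_paused = threading.Event()
    self.done = False
    self.result = None

  def run(self):
    self.should_run.wait()       # wait until the controller releases us
    self.should_run.clear()
    self.result = "step finished"
    self.done = True             # real code sets this only when fn completes
    self.has_paused.set()        # tell the controller we paused/finished


worker = WorkerToy()
worker.start()
worker.should_run.set()          # controller: run one step
worker.has_paused.wait()         # controller: wait for the pause/finish signal
assert worker.done and worker.result == "step finished"
worker.join()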