Example 1
    def testNested(self):
        result = values.regroup({
            _device_str(0): _nested_value("1"),
            _device_str(1): _nested_value("2")
        })
        self.assertIsInstance(result, tuple)
        self.assertEqual(3, len(result))
        self._is_per_device(result[0], ["a1", "a2"])
        self._is_per_device(result[2], ["h1", "h2"])

        self.assertIsInstance(result[1], list)
        self.assertEqual(3, len(result[1]))
        self._is_per_device(result[1][0], ["b1", "b2"])
        self._is_per_device(result[1][2], ["g1", "g2"])

        self.assertIsInstance(result[1][1], dict)
        self.assertEqual(set(["c", "e"]), set(result[1][1].keys()))
        self._is_per_device(result[1][1]["c"], ["d1", "d2"])
        self._is_per_device(result[1][1]["e"], ["f1", "f2"])

        # Also test that we can undo the merge using select_device()
        self.assertEqual(_nested_value("1"),
                         values.select_device(_device_str(0), result))
        self.assertEqual(_nested_value("2"),
                         values.select_device(_device_str(1), result))
        # select_device_mirrored() should fail due to non-mirrored values
        with self.assertRaises(TypeError):
            values.select_device_mirrored(_device_str(0), result)
        with self.assertRaises(TypeError):
            values.select_device_mirrored(_device_str(1), result)
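The test above relies on helpers defined elsewhere in the test module. A minimal sketch of what they might look like, reconstructed from the assertions themselves (hypothetical; the real definitions in values_test.py may differ):

def _device_str(i):
  # Assumed device-string helper; the device type (GPU here) is a guess.
  return "/device:GPU:" + str(i)


def _nested_value(d):
  # Nested structure whose leaves carry the per-device suffix `d`, shaped to
  # match the assertions above (result[0] -> "a"+d, result[1][1]["c"] -> "d"+d,
  # result[2] -> "h"+d, and so on).
  return ("a" + d, ["b" + d, {"c": "d" + d, "e": "f" + d}, "g" + d], "h" + d)


def _is_per_device(self, result, expected, klass=values.PerDevice):
  # Assumed assertion helper (a method on the test case class): `result`
  # should wrap one leaf per device, retrievable with get() in device order.
  self.assertIsInstance(result, klass)
  for i, exp in enumerate(expected):
    self.assertEqual(exp, result.get(_device_str(i)))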
Example 2
    def _update(self, var, fn, *args, **kwargs):
        # TODO(jhseu): Consider supporting grouped==False.
        assert isinstance(var, values.TPUMirroredVariable)
        if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
            return fn(var, *args, **kwargs)

        # Otherwise, we revert to MirroredStrategy behavior and update each variable
        # directly.
        updates = {}
        for d, v in var._index.items():  # pylint: disable=protected-access
            name = "update_%d" % self._device_index.get(d)
            with ops.device(d), distribute_lib.UpdateContext(
                    d), ops.name_scope(name):
                # If args and kwargs are not mirrored, the value is returned as is.
                updates[d] = fn(v, *values.select_device_mirrored(d, args),
                                **values.select_device_mirrored(d, kwargs))

        # Make a single control dependency to keep the variables mirrored. If one
        # assignment is fetched, then run all assignments.
        sorted_keys = sorted(updates.keys())
        update_tuple = control_flow_ops.tuple(
            [updates[d] for d in sorted_keys])
        for i, d in enumerate(sorted_keys):
            updates[d] = update_tuple[i]
        return values.regroup(updates, values.Mirrored)
Example 3
    def testWrapClass(self):
        # Normally a mirrored value would be the same across devices, but
        # for a test it is convenient to be able to tell the values apart.
        result = values.regroup(
            {
                _device_str(0): _nested_value("1"),
                _device_str(1): _nested_value("2")
            }, values.Mirrored)
        self.assertIsInstance(result, tuple)
        self.assertEqual(3, len(result))
        self._is_per_device(result[0], ["a1", "a2"], values.Mirrored)
        self._is_per_device(result[2], ["h1", "h2"], values.Mirrored)

        self.assertIsInstance(result[1], list)
        self.assertEqual(3, len(result[1]))
        self._is_per_device(result[1][0], ["b1", "b2"], values.Mirrored)
        self._is_per_device(result[1][2], ["g1", "g2"], values.Mirrored)

        self.assertIsInstance(result[1][1], dict)
        self.assertEqual(set(["c", "e"]), set(result[1][1].keys()))
        self._is_per_device(result[1][1]["c"], ["d1", "d2"], values.Mirrored)
        self._is_per_device(result[1][1]["e"], ["f1", "f2"], values.Mirrored)

        # Also test that we can undo the merge using select_device()
        self.assertEqual(_nested_value("1"),
                         values.select_device(_device_str(0), result))
        self.assertEqual(_nested_value("2"),
                         values.select_device(_device_str(1), result))
        # Values are marked as mirrored, so select_device_mirrored() is allowed.
        self.assertEqual(_nested_value("1"),
                         values.select_device_mirrored(_device_str(0), result))
        self.assertEqual(_nested_value("2"),
                         values.select_device_mirrored(_device_str(1), result))
Example 4
 def testMirroredContainer(self):
     if context.num_gpus() < 1 and context.executing_eagerly():
         self.skipTest(
             "A GPU is not available for this test in eager mode.")
     v, devices, mirrored = _make_mirrored()
     result = values.regroup(dict(zip(devices, v)))
     self.assertIs(mirrored, result)
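`_make_mirrored()` is another helper from the test module. A plausible sketch, assuming it builds one resource variable per device and wraps the resulting index dict in a MirroredVariable (a hypothetical reconstruction; the device list and initializers are guesses), consistent with the `values.MirroredVariable(index, v)` call in the testOneDevice example below:

def _make_mirrored():
  # Hypothetical: one resource variable per device, collected into the same
  # index dict shape that MirroredVariable and regroup() both understand.
  devices = ["/device:GPU:0", "/device:CPU:0"]
  v, index = [], {}
  for d, name, init in zip(devices, ["v", "v/replica"], [1., 2.]):
    with ops.device(d):
      v.append(variable_scope.get_variable(
          name=name, initializer=init, use_resource=True))
      index[d] = v[-1]
  mirrored = values.MirroredVariable(index, v[0])
  return v, devices, mirrored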
Example 5
  def testNested(self):
    result = values.regroup({_device_str(0): _nested_value("1"),
                             _device_str(1): _nested_value("2")})
    self.assertIsInstance(result, tuple)
    self.assertEqual(3, len(result))
    self._is_per_device(result[0], ["a1", "a2"])
    self._is_per_device(result[2], ["h1", "h2"])

    self.assertIsInstance(result[1], list)
    self.assertEqual(3, len(result[1]))
    self._is_per_device(result[1][0], ["b1", "b2"])
    self._is_per_device(result[1][2], ["g1", "g2"])

    self.assertIsInstance(result[1][1], dict)
    self.assertEqual(set(["c", "e"]), set(result[1][1].keys()))
    self._is_per_device(result[1][1]["c"], ["d1", "d2"])
    self._is_per_device(result[1][1]["e"], ["f1", "f2"])

    # Also test that we can undo the merge using select_device()
    self.assertEqual(_nested_value("1"),
                     values.select_device(_device_str(0), result))
    self.assertEqual(_nested_value("2"),
                     values.select_device(_device_str(1), result))
    # select_device_mirrored() should fail due to non-mirrored values
    with self.assertRaises(TypeError):
      values.select_device_mirrored(_device_str(0), result)
    with self.assertRaises(TypeError):
      values.select_device_mirrored(_device_str(1), result)
Example 6
    def testNamedTupleEstimatorSpec(self):
        with context.graph_mode(), ops.Graph().as_default():
            created_estimator_specs = {}
            to_regroup = {}

            for device_id in range(3):
                spec = model_fn_lib.EstimatorSpec(
                    mode=model_fn_lib.ModeKeys.TRAIN,
                    loss=constant_op.constant(device_id / 2),
                    train_op=array_ops.identity(
                        constant_op.constant(device_id)))
                created_estimator_specs[device_id] = spec
                to_regroup[_device_str(device_id)] = spec

            merged_estimator_spec = values.regroup(to_regroup)

            self.assertTrue(
                isinstance(merged_estimator_spec, model_fn_lib.EstimatorSpec))
            self.assertEquals(model_fn_lib.ModeKeys.TRAIN,
                              merged_estimator_spec.mode)
            for device_id in range(3):
                d = _device_str(device_id)
                self.assertEquals(created_estimator_specs[device_id].loss,
                                  merged_estimator_spec.loss.get(d))
                self.assertEquals(created_estimator_specs[device_id].train_op,
                                  merged_estimator_spec.train_op.get(d))
                # Scaffold is populated by `EstimatorSpec.__new__`.
                self.assertEquals(created_estimator_specs[device_id].scaffold,
                                  merged_estimator_spec.scaffold.get(d))
                # Also test that we can undo the merge using select_device()
                self.assertEquals(
                    created_estimator_specs[device_id],
                    values.select_device(_device_str(device_id),
                                         merged_estimator_spec))
Example 7
  def testOneDevice(self):
    result = values.regroup({_device_str(0): _nested_value("1")})
    # On one device regroup() and select_device() are basically identity.
    self.assertEqual(_nested_value("1"), result)
    self.assertEqual(_nested_value("1"),
                     values.select_device(_device_str(0), result))

    # The one exception has to do with MirroredVariables.
    d = "/device:CPU:0"
    with ops.device(d):
      v = variable_scope.get_variable(
          name="v", initializer=1., use_resource=True)
      index = {d: v}
    mirrored = values.MirroredVariable(index, v)
    result = values.regroup(index)
    self.assertIs(mirrored, result)
Example 8
  def testNamedTupleEstimatorSpec(self):
    with context.graph_mode(), ops.Graph().as_default():
      created_estimator_specs = {}
      to_regroup = {}

      for device_id in range(3):
        spec = model_fn_lib.EstimatorSpec(
            mode=model_fn_lib.ModeKeys.TRAIN,
            loss=constant_op.constant(device_id / 2),
            train_op=array_ops.identity(constant_op.constant(device_id)))
        created_estimator_specs[device_id] = spec
        to_regroup[_device_str(device_id)] = spec

      merged_estimator_spec = values.regroup(to_regroup)

      self.assertTrue(
          isinstance(merged_estimator_spec, model_fn_lib.EstimatorSpec))
      self.assertEquals(model_fn_lib.ModeKeys.TRAIN, merged_estimator_spec.mode)
      for device_id in range(3):
        d = _device_str(device_id)
        self.assertEquals(created_estimator_specs[device_id].loss,
                          merged_estimator_spec.loss.get(d))
        self.assertEquals(created_estimator_specs[device_id].train_op,
                          merged_estimator_spec.train_op.get(d))
        # Scaffold is populated by `EstimatorSpec.__new__`.
        self.assertEquals(created_estimator_specs[device_id].scaffold,
                          merged_estimator_spec.scaffold.get(d))
        # Also test that we can undo the merge using select_device()
        self.assertEquals(created_estimator_specs[device_id],
                          values.select_device(_device_str(device_id),
                                               merged_estimator_spec))
Example 9
  def testWrapClass(self):
    # Normally a mirrored value would be the same across devices, but
    # for a test it is convenient to be able to tell the values apart.
    result = values.regroup({_device_str(0): _nested_value("1"),
                             _device_str(1): _nested_value("2")},
                            values.Mirrored)
    self.assertIsInstance(result, tuple)
    self.assertEqual(3, len(result))
    self._is_per_device(result[0], ["a1", "a2"], values.Mirrored)
    self._is_per_device(result[2], ["h1", "h2"], values.Mirrored)

    self.assertIsInstance(result[1], list)
    self.assertEqual(3, len(result[1]))
    self._is_per_device(result[1][0], ["b1", "b2"], values.Mirrored)
    self._is_per_device(result[1][2], ["g1", "g2"], values.Mirrored)

    self.assertIsInstance(result[1][1], dict)
    self.assertEqual(set(["c", "e"]), set(result[1][1].keys()))
    self._is_per_device(result[1][1]["c"], ["d1", "d2"], values.Mirrored)
    self._is_per_device(result[1][1]["e"], ["f1", "f2"], values.Mirrored)

    # Also test that we can undo the merge using select_device()
    self.assertEqual(_nested_value("1"),
                     values.select_device(_device_str(0), result))
    self.assertEqual(_nested_value("2"),
                     values.select_device(_device_str(1), result))
    # Values are marked as mirrored, so select_device_mirrored() is allowed.
    self.assertEqual(_nested_value("1"),
                     values.select_device_mirrored(_device_str(0), result))
    self.assertEqual(_nested_value("2"),
                     values.select_device_mirrored(_device_str(1), result))
Example 10
  def _run_steps_on_dataset(self, fn, iterator, iterations,
                            initial_loop_values=None):
    if initial_loop_values is None:
      initial_loop_values = {}
    initial_loop_values = nest.flatten(initial_loop_values)

    ctx = values.MultiStepContext()
    def body(i, *args):
      """A wrapper around `fn` to create the while loop body."""
      del args
      fn_inputs = iterator.get_next()
      if not isinstance(fn_inputs, tuple):
        fn_inputs = (fn_inputs,)
      fn_result = fn(ctx, *fn_inputs)
      for (name, output) in ctx.last_step_outputs.items():
        # Convert all outputs to tensors, potentially from `DistributedValues`.
        ctx.last_step_outputs[name] = self.unwrap(output)
      flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
      with ops.control_dependencies([fn_result]):
        return [i + 1] + flat_last_step_outputs

    # We capture the control_flow_context at this point, before we run `fn`
    # inside a while_loop. This is useful in cases where we might need to exit
    # these contexts and get back to the outer context to do some things, for
    # e.g. create an op which should be evaluated only once at the end of the
    # loop on the host. One such usage is in creating metrics' value op.
    self._outer_control_flow_context = (
        ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

    cond = lambda i, *args: i < iterations
    i = constant_op.constant(0)
    loop_result = control_flow_ops.while_loop(
        cond, body, [i] + initial_loop_values, name="",
        parallel_iterations=1, back_prop=False, swap_memory=False,
        return_same_structure=True)
    del self._outer_control_flow_context

    ctx.run_op = control_flow_ops.group(loop_result)

    # Convert the last_step_outputs from a list to the original dict structure
    # of last_step_outputs.
    last_step_tensor_outputs = loop_result[1:]
    last_step_tensor_outputs_dict = nest.pack_sequence_as(
        ctx.last_step_outputs, last_step_tensor_outputs)

    for (name, aggregation) in ctx._last_step_outputs_aggregations.items():  # pylint: disable=protected-access
      output = last_step_tensor_outputs_dict[name]
      # For outputs that have already been aggregated, wrap them in a Mirrored
      # container, else in a PerDevice container.
      if aggregation is variables_lib.VariableAggregation.NONE:
        last_step_tensor_outputs_dict[name] = values.regroup(
            {d: t for d, t in zip(self._devices, output)}, values.PerDevice)
      else:
        assert len(output) == 1
        last_step_tensor_outputs_dict[name] = output[0]

    ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access
    return ctx
Example 11
 def _update_non_slot(self, colocate_with, fn, *args, **kwargs):
   assert isinstance(colocate_with, list)
   # TODO(josh11b): In eager mode, use one thread per device.
   updates = {}
   for d in colocate_with:
     name = "update_%d" % self._device_index.get(d)
     with ops.device(d), distribute_lib.UpdateContext(d), ops.name_scope(name):
       updates[d] = fn(*values.select_device_mirrored(d, args),
                       **values.select_device_mirrored(d, kwargs))
   return values.regroup(updates, values.Mirrored)
Example 12
    def _run_steps_on_dataset(self,
                              fn,
                              iterator,
                              iterations,
                              initial_loop_values=None):
        if initial_loop_values is None:
            initial_loop_values = {}
        initial_loop_values = nest.flatten(initial_loop_values)

        ctx = values.MultiStepContext()

        def body(i, *args):
            """A wrapper around `fn` to create the while loop body."""
            del args
            fn_result = fn(ctx, iterator.get_next())
            for (name, output) in ctx.last_step_outputs.items():
                # Convert all outputs to tensors, potentially from `DistributedValues`.
                ctx.last_step_outputs[name] = self.unwrap(output)
            flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
            with ops.control_dependencies([fn_result]):
                return [i + 1] + flat_last_step_outputs

        cond = lambda i, *args: i < iterations
        i = constant_op.constant(0)
        loop_result = control_flow_ops.while_loop(cond,
                                                  body,
                                                  [i] + initial_loop_values,
                                                  name="",
                                                  parallel_iterations=1,
                                                  back_prop=False,
                                                  swap_memory=False,
                                                  return_same_structure=True)

        ctx.run_op = control_flow_ops.group(loop_result)

        # Convert the last_step_outputs from a list to the original dict structure
        # of last_step_outputs.
        last_step_tensor_outputs = loop_result[1:]
        last_step_tensor_outputs_dict = nest.pack_sequence_as(
            ctx.last_step_outputs, last_step_tensor_outputs)

        for (name, aggregation) in ctx._last_step_outputs_aggregations.items():  # pylint: disable=protected-access
            output = last_step_tensor_outputs_dict[name]
            # For outputs that have already been aggregated, wrap them in a Mirrored
            # container, else in a PerDevice container.
            if aggregation is variables_lib.VariableAggregation.NONE:
                last_step_tensor_outputs_dict[name] = values.regroup(
                    {d: t
                     for d, t in zip(self._devices, output)}, values.PerDevice)
            else:
                assert len(output) == 1
                last_step_tensor_outputs_dict[name] = output[0]

        ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access
        return ctx
Example 13
 def _update(self, var, fn, *args, **kwargs):
   # TODO(josh11b): In eager mode, use one thread per device.
   assert isinstance(var, values.DistributedVariable)
   updates = {}
   for d, v in var._index.items():  # pylint: disable=protected-access
     name = "update_%d" % self._device_index.get(d)
     with ops.device(d), distribute_lib.UpdateContext(d), ops.name_scope(name):
       # If args and kwargs are not mirrored, the value is returned as is.
       updates[d] = fn(v,
                       *values.select_device_mirrored(d, args),
                       **values.select_device_mirrored(d, kwargs))
   return values.regroup(updates, values.Mirrored)
Example 14
 def _update(self, var, fn, *args, **kwargs):
     # TODO(josh11b): Also support TowerLocalVariables here? If so, args and
     # kwargs don't need to be mirrored.
     assert isinstance(var, values.MirroredVariable)
     # TODO(josh11b): In eager mode, use one thread per device.
     updates = {}
     for d, v in var._index.items():  # pylint: disable=protected-access
         name = "update_%d" % self._device_index.get(d)
         with ops.device(d), distribute_lib.UpdateContext(
                 d), ops.name_scope(name):
             updates[d] = fn(v, *values.select_device_mirrored(d, args),
                             **values.select_device_mirrored(d, kwargs))
     return values.regroup(updates, values.Mirrored)
Example 15
 def _update(self, var, fn, *args, **kwargs):
   # TODO(josh11b): Also support TowerLocalVariables here? If so, args and
   # kwargs don't need to be mirrored.
   assert isinstance(var, values.MirroredVariable)
   # TODO(josh11b): In eager mode, use one thread per device.
   updates = {}
   for d, v in var._index.items():  # pylint: disable=protected-access
     name = "update_%d" % self._device_index.get(d)
     with ops.device(d), distribute_lib.UpdateContext(d), ops.name_scope(name):
       updates[d] = fn(v,
                       *values.select_device_mirrored(d, args),
                       **values.select_device_mirrored(d, kwargs))
   return values.regroup(updates, values.Mirrored)
Example 16
  def _run_steps_on_dataset(self, fn, iterator, iterations,
                            initial_loop_values=None):
    if initial_loop_values is None:
      initial_loop_values = {}
    initial_loop_values = nest.flatten(initial_loop_values)

    ctx = values.MultiStepContext()
    def body(i, *args):
      """A wrapper around `fn` to create the while loop body."""
      del args
      fn_inputs = iterator.get_next()
      if not isinstance(fn_inputs, tuple):
        fn_inputs = (fn_inputs,)
      fn_result = fn(ctx, *fn_inputs)
      for (name, output) in ctx.last_step_outputs.items():
        # Convert all outputs to tensors, potentially from `DistributedValues`.
        ctx.last_step_outputs[name] = self.unwrap(output)
      flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
      with ops.control_dependencies([fn_result]):
        return [i + 1] + flat_last_step_outputs

    cond = lambda i, *args: i < iterations
    i = constant_op.constant(0)
    loop_result = control_flow_ops.while_loop(
        cond, body, [i] + initial_loop_values, name="",
        parallel_iterations=1, back_prop=False, swap_memory=False,
        return_same_structure=True)

    ctx.run_op = control_flow_ops.group(loop_result)

    # Convert the last_step_outputs from a list to the original dict structure
    # of last_step_outputs.
    last_step_tensor_outputs = loop_result[1:]
    last_step_tensor_outputs_dict = nest.pack_sequence_as(
        ctx.last_step_outputs, last_step_tensor_outputs)

    for (name, aggregation) in ctx._last_step_outputs_aggregations.items():  # pylint: disable=protected-access
      output = last_step_tensor_outputs_dict[name]
      # For outputs that have already been aggregated, wrap them in a Mirrored
      # container, else in a PerDevice container.
      if aggregation is variables_lib.VariableAggregation.NONE:
        last_step_tensor_outputs_dict[name] = values.regroup(
            {d: t for d, t in zip(self._devices, output)}, values.PerDevice)
      else:
        assert len(output) == 1
        last_step_tensor_outputs_dict[name] = output[0]

    ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access
    return ctx
Example 17
  def testSameId(self):
    foo = object()
    result = values.regroup({_device_str(0): ("a", foo),
                             _device_str(1): ("b", foo)})
    self.assertIsInstance(result, tuple)
    self.assertEqual(2, len(result))
    self._is_per_device(result[0], ["a", "b"])
    self.assertIs(foo, result[1])

    # Test select_device(), should undo the merge done by regroup().
    result_0 = values.select_device(_device_str(0), result)
    self.assertIsInstance(result_0, tuple)
    self.assertEqual(2, len(result_0))
    self.assertEqual("a", result_0[0])
    self.assertIs(foo, result_0[1])
    result_1 = values.select_device(_device_str(1), result)
    self.assertIsInstance(result_1, tuple)
    self.assertEqual(2, len(result_1))
    self.assertEqual("b", result_1[0])
    self.assertIs(foo, result_1[1])
Example 18
  def _update(self, var, fn, *args, **kwargs):
    # TODO(jhseu): Consider supporting grouped==False.
    assert isinstance(var, values.TPUMirroredVariable)
    if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
      return fn(var, *args, **kwargs)

    # Otherwise, we revert to MirroredStrategy behavior and update each variable
    # directly.
    updates = {}
    for d, v in var._index.items():  # pylint: disable=protected-access
      name = "update_%d" % self._device_index.get(d)
      with ops.device(d), distribute_lib.UpdateContext(d), ops.name_scope(name):
        # If args and kwargs are not mirrored, the value is returned as is.
        updates[d] = fn(v,
                        *values.select_device_mirrored(d, args),
                        **values.select_device_mirrored(d, kwargs))

    # Make a single control dependency to keep the variables mirrored. If one
    # assignment is fetched, then run all assignments.
    sorted_keys = sorted(updates.keys())
    update_tuple = control_flow_ops.tuple([updates[d] for d in sorted_keys])
    for i, d in enumerate(sorted_keys):
      updates[d] = update_tuple[i]
    return values.regroup(updates, values.Mirrored)
Example 19
def _call_for_each_replica(distribution, fn, args, kwargs):
    """Run `fn` in separate threads, once per replica/worker device.

  Args:
    distribution: the DistributionStrategy object.
    fn: function to run (will be run once per device, each in its own thread).
    args: positional arguments for `fn`
    kwargs: keyword arguments for `fn`.

  Returns:
    Merged return value of `fn` across all replicas.

  Raises:
    RuntimeError: If fn() calls get_replica_context().merge_call() a different
        number of times from the available devices.
  """
    # TODO(josh11b): Add this option once we add synchronization to variable
    # creation. Until then, this is pretty unsafe to use.
    run_concurrently = False
    if not context.executing_eagerly():
        # Needed for per-thread device, etc. contexts in graph mode.
        ops.get_default_graph().switch_to_thread_local()

    coord = coordinator.Coordinator(
        clean_stop_exception_types=(_RequestedStop, ))

    shared_variable_store = {}

    # TODO(isaprykin): Create these threads once instead of during every run()
    # call.
    threads = []
    for index, d in enumerate(distribution.worker_devices):
        variable_creator_fn = shared_variable_creator.make_fn(
            shared_variable_store, index)
        t = MirroredStrategy._MirroredReplicaThread(  # pylint: disable=protected-access
            distribution, coord, d, variable_creator_fn, fn,
            *values.select_device(d, args), **values.select_device(d, kwargs))
        threads.append(t)

    for t in threads:
        t.start()

    # When `fn` starts `should_run` event is set on _MirroredReplicaThread
    # (`MRT`) threads. The execution waits until
    # `MRT.has_paused` is set, which indicates that either `fn` is
    # complete or a `get_replica_context().merge_call()` is called.  If `fn` is
    # complete, then `MRT.done` is set to True.  Otherwise, arguments
    # of `get_replica_context().merge_call` from all paused threads are grouped
    # and the `merge_fn` is performed.  Results of the
    # `get_replica_context().merge_call` are then set to `MRT.merge_result`.
    # Each such `get_replica_context().merge_call` call returns the
    # `MRT.merge_result` for that thread when `MRT.should_run` event
    # is reset again. Execution of `fn` resumes.

    try:
        with coord.stop_on_exception():
            all_done = False
            while not all_done and not coord.should_stop():
                done = []
                if run_concurrently:
                    for t in threads:
                        t.should_run.set()
                    for t in threads:
                        t.has_paused.wait()
                        t.has_paused.clear()
                        if coord.should_stop():
                            return None
                        done.append(t.done)
                else:
                    for t in threads:
                        t.should_run.set()
                        t.has_paused.wait()
                        t.has_paused.clear()
                        if coord.should_stop():
                            return None
                        done.append(t.done)
                if coord.should_stop():
                    return None
                all_done = all(done)
                if not all_done:
                    if any(done):
                        raise RuntimeError(
                            "Some replicas made a different number of "
                            "replica_context().merge_call() calls.")
                    # get_replica_context().merge_call() case
                    merge_args = values.regroup(
                        {t.device: t.merge_args
                         for t in threads})
                    merge_kwargs = values.regroup(
                        {t.device: t.merge_kwargs
                         for t in threads})
                    # We capture the name_scope of the MRT when we call merge_fn
                    # to ensure that if we have opened a name scope in the MRT,
                    # it will be respected when executing the merge function. We only
                    # capture the name_scope from the first MRT and assume it is
                    # the same for all other MRTs.
                    mtt_captured_name_scope = threads[0].captured_name_scope
                    with ops.name_scope(mtt_captured_name_scope):
                        merge_result = threads[0].merge_fn(
                            distribution, *merge_args, **merge_kwargs)
                    for t in threads:
                        t.merge_result = values.select_device(
                            t.device, merge_result)
    finally:
        for t in threads:
            t.should_run.set()
        coord.join(threads)

    return values.regroup({t.device: t.main_result for t in threads})
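The long comment above describes a hand-rolled handshake between the coordinator and the _MirroredReplicaThread (`MRT`) workers built from the `should_run` and `has_paused` events. The same protocol can be illustrated in isolation; the following is a minimal, self-contained sketch using plain threading and hypothetical names, not the actual _MirroredReplicaThread implementation:

import threading


class _Worker(threading.Thread):
  # Toy stand-in for _MirroredReplicaThread: pauses after each simulated
  # merge point and finishes after a fixed number of steps.

  def __init__(self, steps):
    super(_Worker, self).__init__()
    self.should_run = threading.Event()
    self.has_paused = threading.Event()
    self.done = False
    self._steps = steps

  def run(self):
    for _ in range(self._steps):
      self.should_run.wait()   # wait until the coordinator wakes us up
      self.should_run.clear()
      # ... one replica step / merge_call argument capture happens here ...
      self.has_paused.set()    # report back that we paused again
    self.should_run.wait()     # final wake-up: nothing left to do
    self.done = True
    self.has_paused.set()


workers = [_Worker(steps=3) for _ in range(2)]
for w in workers:
  w.start()

all_done = False
while not all_done:
  for w in workers:
    w.should_run.set()         # resume every worker
  for w in workers:
    w.has_paused.wait()        # block until it pauses at the next merge point
    w.has_paused.clear()
  all_done = all(w.done for w in workers)
for w in workers:
  w.join()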
Example 20
 def testMirroredContainer(self):
   if context.num_gpus() < 1 and context.executing_eagerly():
     self.skipTest("A GPU is not available for this test in eager mode.")
   v, devices, mirrored = _make_mirrored()
   result = values.regroup(dict(zip(devices, v)))
   self.assertIs(mirrored, result)
Example 21
def _call_for_each_tower(distribution, fn, *args, **kwargs):
  """Run `fn` in separate threads, once per tower/worker device.

  Args:
    distribution: the DistributionStrategy object.
    fn: function to run (will be run once per device, each in its own thread).
    *args: positional arguments for `fn`
    **kwargs: keyword arguments for `fn`.
        `"run_concurrently"`: Boolean indicating whether executions of `fn`
           can be run concurrently (under eager execution only), defaults to
           `True`.

  Returns:
    Merged return value of `fn` across all towers.

  Raises:
    RuntimeError: If fn() calls get_tower_context().merge_call() a different
        number of times from the available devices.
  """
  run_concurrently = kwargs.pop("run_concurrently", True)
  if not context.executing_eagerly():
    # Lots of TF library code isn't thread-safe in graph mode, and
    # there is little to be gained by turning on multithreading when
    # constructing a graph.
    run_concurrently = False
    # Needed for per-thread device, etc. contexts in graph mode.
    ops.get_default_graph().switch_to_thread_local()
  elif run_concurrently is None:
    run_concurrently = True

  coord = coordinator.Coordinator(clean_stop_exception_types=(_RequestedStop,))

  shared_variable_store = {}

  # TODO(isaprykin): Create these threads once instead of during every run()
  # call.
  threads = []
  for index, d in enumerate(distribution.worker_devices):
    variable_creator_fn = shared_variable_creator.make_fn(
        shared_variable_store, index)
    t = MirroredStrategy._MirroredTowerThread(  # pylint: disable=protected-access
        distribution, coord, d, variable_creator_fn, fn,
        *values.select_device(d, args), **values.select_device(d, kwargs))
    threads.append(t)

  for t in threads:
    t.start()

  # When `fn` starts `should_run` event is set on _MirroredTowerThread
  # (`MTT`) threads. The execution waits until
  # `MTT.has_paused` is set, which indicates that either `fn` is
  # complete or a `get_tower_context().merge_call()` is called.  If `fn` is
  # complete, then `MTT.done` is set to True.  Otherwise, arguments
  # of `get_tower_context().merge_call` from all paused threads are grouped
  # and the `merge_fn` is performed.  Results of the
  # `get_tower_context().merge_call` are then set to `MTT.merge_result`.
  # Each such `get_tower_context().merge_call` call returns the
  # `MTT.merge_result` for that thread when `MTT.should_run` event
  # is reset again. Execution of `fn` resumes.

  try:
    with coord.stop_on_exception():
      all_done = False
      while not all_done and not coord.should_stop():
        done = []
        if run_concurrently:
          for t in threads:
            t.should_run.set()
          for t in threads:
            t.has_paused.wait()
            t.has_paused.clear()
            if coord.should_stop():
              return None
            done.append(t.done)
        else:
          for t in threads:
            t.should_run.set()
            t.has_paused.wait()
            t.has_paused.clear()
            if coord.should_stop():
              return None
            done.append(t.done)
        if coord.should_stop():
          return None
        all_done = all(done)
        if not all_done:
          if any(done):
            raise RuntimeError("Some towers made a different number of "
                               "tower_context().merge_call() calls.")
          # get_tower_context().merge_call() case
          merge_args = values.regroup({t.device: t.merge_args for t in threads})
          merge_kwargs = values.regroup(
              {t.device: t.merge_kwargs for t in threads})
          # We capture the name_scope of the MTT when we call merge_fn
          # to ensure that if we have opened a name scope in the MTT,
          # it will be respected when executing the merge function. We only
          # capture the name_scope from the first MTT and assume it is
          # the same for all other MTTs.
          mtt_captured_name_scope = threads[0].captured_name_scope
          with ops.name_scope(mtt_captured_name_scope):
            merge_result = threads[0].merge_fn(distribution, *merge_args,
                                               **merge_kwargs)
          for t in threads:
            t.merge_result = values.select_device(t.device, merge_result)
  finally:
    for t in threads:
      t.should_run.set()
    coord.join(threads)

  return values.regroup({t.device: t.main_result for t in threads})
Example 22
    def _call_for_each_tower(self, fn, *args, **kwargs):
        """Run `fn` in separate threads, once per tower/worker device.

    Args:
      fn: function to run (will be run once per device, each in its own thread).
      *args: positional arguments for `fn`
      **kwargs: keyword arguments for `fn`.
          `"run_concurrently"`: Boolean indicating whether executions of `fn`
             can be run concurrently (under eager execution only), defaults to
             `True`.

    Returns:
      Merged return value of `fn` across all towers.

    Raises:
      RuntimeError: If fn() calls get_tower_context().merge_call() a different
          number of times when called for different devices.
    """
        run_concurrently = kwargs.pop("run_concurrently", True)
        if not context.executing_eagerly():
            # Lots of TF library code isn't thread-safe in graph mode, and
            # there is little to be gained by turning on multithreading when
            # constructing a graph.
            run_concurrently = False
            # Needed for per-thread device, etc. contexts in graph mode.
            ops.get_default_graph().switch_to_thread_local()
        elif run_concurrently is None:
            run_concurrently = True

        coord = coordinator.Coordinator(
            clean_stop_exception_types=(_RequestedStop, ))

        shared_variable_store = {}

        # TODO(isaprykin): Create these threads once instead of during every run()
        # call.
        threads = []
        for index, d in enumerate(self._devices):
            variable_creator_fn = shared_variable_creator.make_fn(
                shared_variable_store, index)
            t = MirroredStrategy._MirroredTowerThread(
                self, coord, d, variable_creator_fn, fn,
                *values.select_device(d, args),
                **values.select_device(d, kwargs))
            threads.append(t)

        for t in threads:
            t.start()

        # When `fn` starts `should_run` event is set on _MirroredTowerThread
        # (`MTT`) threads. The execution waits until
        # `MTT.has_paused` is set, which indicates that either `fn` is
        # complete or a `get_tower_context().merge_call()` is called.  If `fn` is
        # complete, then `MTT.done` is set to True.  Otherwise, arguments
        # of `get_tower_context().merge_call` from all paused threads are grouped
        # and the `merge_fn` is performed.  Results of the
        # `get_tower_context().merge_call` are then set to `MTT.merge_result`.
        # Each such `get_tower_context().merge_call` call returns the
        # `MTT.merge_result` for that thread when `MTT.should_run` event
        # is reset again. Execution of `fn` resumes.

        try:
            with coord.stop_on_exception():
                all_done = False
                while not all_done and not coord.should_stop():
                    done = []
                    if run_concurrently:
                        for t in threads:
                            t.should_run.set()
                        for t in threads:
                            t.has_paused.wait()
                            t.has_paused.clear()
                            if coord.should_stop():
                                return None
                            done.append(t.done)
                    else:
                        for t in threads:
                            t.should_run.set()
                            t.has_paused.wait()
                            t.has_paused.clear()
                            if coord.should_stop():
                                return None
                            done.append(t.done)
                    if coord.should_stop():
                        return None
                    all_done = all(done)
                    if not all_done:
                        if any(done):
                            raise RuntimeError(
                                "Some towers made a different number of "
                                "tower_context().merge_call() calls.")
                        # get_tower_context().merge_call() case
                        merge_args = values.regroup(
                            {t.device: t.merge_args
                             for t in threads})
                        merge_kwargs = values.regroup(
                            {t.device: t.merge_kwargs
                             for t in threads})
                        merge_result = threads[0].merge_fn(
                            self, *merge_args, **merge_kwargs)
                        for t in threads:
                            t.merge_result = values.select_device(
                                t.device, merge_result)
        finally:
            for t in threads:
                t.should_run.set()
            coord.join(threads)

        return values.regroup({t.device: t.main_result for t in threads})