Ejemplo n.º 1
0
  def testWrapClass(self):
    """Regrouping with `values.Mirrored` wraps every leaf of the structure.

    Normally a mirrored value would be the same across devices, but for a test
    it is convenient to be able to tell the values apart, so each device gets
    its own distinguishable nested structure. Every leaf is checked with
    `_is_per_device` against the expected components and `values.Mirrored`.
    """
    wrap_class = values.Mirrored
    per_device_inputs = {
        _device_str(0): _nested_value("1"),
        _device_str(1): _nested_value("2"),
    }
    result = values.regroup(per_device_inputs, wrap_class)

    # Outer level: a 3-tuple.
    self.assertIsInstance(result, tuple)
    self.assertEqual(3, len(result))
    self._is_per_device(result[0], ["a1", "a2"], wrap_class)
    self._is_per_device(result[2], ["h1", "h2"], wrap_class)

    # result[1]: a 3-element list.
    middle = result[1]
    self.assertIsInstance(middle, list)
    self.assertEqual(3, len(middle))
    self._is_per_device(middle[0], ["b1", "b2"], wrap_class)
    self._is_per_device(middle[2], ["g1", "g2"], wrap_class)

    # result[1][1]: a dict keyed by "c" and "e".
    inner = middle[1]
    self.assertIsInstance(inner, dict)
    self.assertEqual(set(["c", "e"]), set(inner.keys()))
    self._is_per_device(inner["c"], ["d1", "d2"], wrap_class)
    self._is_per_device(inner["e"], ["f1", "f2"], wrap_class)

    # select_device() must undo the merge for each device.
    for index, suffix in enumerate(("1", "2")):
      self.assertEqual(_nested_value(suffix),
                       values.select_device(_device_str(index), result))
    # Every leaf is marked as mirrored, so select_device_mirrored() works too.
    for index, suffix in enumerate(("1", "2")):
      self.assertEqual(
          _nested_value(suffix),
          values.select_device_mirrored(_device_str(index), result))
Ejemplo n.º 2
0
  def testNested(self):
    """Regrouping without a wrap class yields plain per-device leaves.

    Each leaf of the regrouped structure is checked with `_is_per_device`, and
    select_device() must restore each device's original structure. Because no
    leaf is marked as mirrored, select_device_mirrored() must raise TypeError.
    """
    inputs = {
        _device_str(0): _nested_value("1"),
        _device_str(1): _nested_value("2"),
    }
    result = values.regroup(inputs)

    # Outer tuple of length 3.
    self.assertIsInstance(result, tuple)
    self.assertEqual(3, len(result))
    self._is_per_device(result[0], ["a1", "a2"])
    self._is_per_device(result[2], ["h1", "h2"])

    # result[1] is a list of length 3.
    as_list = result[1]
    self.assertIsInstance(as_list, list)
    self.assertEqual(3, len(as_list))
    self._is_per_device(as_list[0], ["b1", "b2"])
    self._is_per_device(as_list[2], ["g1", "g2"])

    # result[1][1] is a dict with keys "c" and "e".
    as_dict = as_list[1]
    self.assertIsInstance(as_dict, dict)
    self.assertEqual(set(["c", "e"]), set(as_dict.keys()))
    self._is_per_device(as_dict["c"], ["d1", "d2"])
    self._is_per_device(as_dict["e"], ["f1", "f2"])

    # select_device() undoes the merge for each device.
    for index, suffix in enumerate(("1", "2")):
      self.assertEqual(_nested_value(suffix),
                       values.select_device(_device_str(index), result))
    # Nothing is mirrored, so select_device_mirrored() must refuse.
    for index in range(2):
      with self.assertRaises(TypeError):
        values.select_device_mirrored(_device_str(index), result)
Ejemplo n.º 3
0
    def testWrapClass(self):
        """Checks that regroup() with `values.Mirrored` wraps each leaf.

        A real mirrored value would be identical on every device. The test
        feeds two distinguishable structures instead, so each leaf's
        per-device components can be told apart. `_is_per_device` checks every
        leaf against the expected components and the wrap class.
        """
        mirrored = values.Mirrored
        result = values.regroup(
            {
                _device_str(0): _nested_value("1"),
                _device_str(1): _nested_value("2"),
            },
            mirrored,
        )

        # Top level: (leaf, list, leaf).
        self.assertIsInstance(result, tuple)
        self.assertEqual(3, len(result))
        first, middle, last = result
        self._is_per_device(first, ["a1", "a2"], mirrored)
        self._is_per_device(last, ["h1", "h2"], mirrored)

        # Middle: [leaf, dict, leaf].
        self.assertIsInstance(middle, list)
        self.assertEqual(3, len(middle))
        head, mapping, tail = middle
        self._is_per_device(head, ["b1", "b2"], mirrored)
        self._is_per_device(tail, ["g1", "g2"], mirrored)

        # Innermost: {"c": leaf, "e": leaf}.
        self.assertIsInstance(mapping, dict)
        self.assertEqual(set(["c", "e"]), set(mapping.keys()))
        self._is_per_device(mapping["c"], ["d1", "d2"], mirrored)
        self._is_per_device(mapping["e"], ["f1", "f2"], mirrored)

        # Undo the merge per device with select_device().
        for device_number, tag in enumerate(("1", "2")):
            self.assertEqual(
                _nested_value(tag),
                values.select_device(_device_str(device_number), result))
        # All leaves are Mirrored, so select_device_mirrored() is permitted.
        for device_number, tag in enumerate(("1", "2")):
            self.assertEqual(
                _nested_value(tag),
                values.select_device_mirrored(_device_str(device_number),
                                              result))
Ejemplo n.º 4
0
    def testNested(self):
        """Checks regroup() on nested structures without a wrap class.

        Every leaf must be a per-device value holding both devices'
        components. select_device() must recover each device's input.
        select_device_mirrored() must raise TypeError, because nothing was
        marked as mirrored.
        """
        result = values.regroup({
            _device_str(0): _nested_value("1"),
            _device_str(1): _nested_value("2"),
        })

        # Top level: (leaf, list, leaf).
        self.assertIsInstance(result, tuple)
        self.assertEqual(3, len(result))
        first, middle, last = result
        self._is_per_device(first, ["a1", "a2"])
        self._is_per_device(last, ["h1", "h2"])

        # Middle: [leaf, dict, leaf].
        self.assertIsInstance(middle, list)
        self.assertEqual(3, len(middle))
        head, mapping, tail = middle
        self._is_per_device(head, ["b1", "b2"])
        self._is_per_device(tail, ["g1", "g2"])

        # Innermost: {"c": leaf, "e": leaf}.
        self.assertIsInstance(mapping, dict)
        self.assertEqual(set(["c", "e"]), set(mapping.keys()))
        self._is_per_device(mapping["c"], ["d1", "d2"])
        self._is_per_device(mapping["e"], ["f1", "f2"])

        # Undo the merge per device with select_device().
        for device_number, tag in enumerate(("1", "2")):
            self.assertEqual(
                _nested_value(tag),
                values.select_device(_device_str(device_number), result))
        # The values are not mirrored, so select_device_mirrored() must fail.
        for device_number in (0, 1):
            with self.assertRaises(TypeError):
                values.select_device_mirrored(_device_str(device_number),
                                              result)
Ejemplo n.º 5
0
 def _update_non_slot(self, colocate_with, fn, *args, **kwargs):
   """Runs `fn` once per device in `colocate_with` and regroups as Mirrored.

   Each call runs under `ops.device`, `distribute_lib.UpdateContext` and an
   "update_<index>" name scope for its device. The index comes from
   `self._device_index`.

   Args:
     colocate_with: list of devices to run `fn` on.
     fn: callable invoked once per device.
     *args: positional arguments, narrowed to each device with
       `values.select_device_mirrored`.
     **kwargs: keyword arguments, narrowed the same way.

   Returns:
     The per-device results regrouped with `values.Mirrored`.
   """
   assert isinstance(colocate_with, list)
   # TODO(josh11b): In eager mode, use one thread per device.
   per_device = {}
   for device in colocate_with:
     scope_name = "update_%d" % self._device_index.get(device)
     with ops.device(device):
       with distribute_lib.UpdateContext(device):
         with ops.name_scope(scope_name):
           device_args = values.select_device_mirrored(device, args)
           device_kwargs = values.select_device_mirrored(device, kwargs)
           per_device[device] = fn(*device_args, **device_kwargs)
   return values.regroup(per_device, values.Mirrored)
Ejemplo n.º 6
0
 def _update_non_slot(self, colocate_with, fn, *args, **kwargs):
   """Applies `fn` on each device of `colocate_with`; returns a Mirrored group.

   Args:
     colocate_with: list of devices; `fn` runs once per entry.
     fn: callable receiving this device's slice of `args` and `kwargs`, as
       produced by `values.select_device_mirrored`.
     *args: positional arguments forwarded to `fn`.
     **kwargs: keyword arguments forwarded to `fn`.

   Returns:
     `values.regroup` of the per-device results with `values.Mirrored`.
   """
   assert isinstance(colocate_with, list)
   # TODO(josh11b): In eager mode, use one thread per device.
   outputs = {}
   for target in colocate_with:
     label = "update_%d" % self._device_index.get(target)
     with ops.device(target), distribute_lib.UpdateContext(target), \
         ops.name_scope(label):
       outputs[target] = fn(
           *values.select_device_mirrored(target, args),
           **values.select_device_mirrored(target, kwargs))
   return values.regroup(outputs, values.Mirrored)
Ejemplo n.º 7
0
 def _update(self, var, fn, *args, **kwargs):
   """Applies `fn` to every per-device component of `var`.

   Args:
     var: a `values.DistributedVariable`. Its private `_index` supplies the
       (device, component) pairs to update.
     fn: callable invoked as `fn(component, *device_args, **device_kwargs)`.
     *args: positional arguments, narrowed to each device with
       `values.select_device_mirrored`.
     **kwargs: keyword arguments, narrowed the same way.

   Returns:
     The per-device results regrouped with `values.Mirrored`.
   """
   # TODO(josh11b): In eager mode, use one thread per device.
   assert isinstance(var, values.DistributedVariable)
   results = {}
   for device, component in var._index.items():  # pylint: disable=protected-access
     scope = "update_%d" % self._device_index.get(device)
     with ops.device(device):
       with distribute_lib.UpdateContext(device):
         with ops.name_scope(scope):
           # If args and kwargs are not mirrored, the value is returned as is.
           device_args = values.select_device_mirrored(device, args)
           device_kwargs = values.select_device_mirrored(device, kwargs)
           results[device] = fn(component, *device_args, **device_kwargs)
   return values.regroup(results, values.Mirrored)
Ejemplo n.º 8
0
 def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs):
   """Runs `fn` once per device in `colocate_with` and regroups the results.

   Args:
     colocate_with: list of devices to run `fn` on.
     options: dict that must contain exactly the key "grouped". The key is
       popped, so the caller's dict is mutated. Any remaining key fails the
       assertion.
     fn: callable invoked once per device.
     *args: positional arguments, narrowed per device with
       `values.select_device_mirrored`.
     **kwargs: keyword arguments, narrowed the same way.

   Returns:
     The result of `values.update_regroup(self, updates, grouped)`.
   """
   assert isinstance(colocate_with, list)
   grouped = options.pop("grouped")
   assert not options  # Validate that we are processing all of the options.
   # TODO(josh11b): In eager mode, use one thread per device.
   per_device = {}
   for device in colocate_with:
     scope_name = "update_%d" % self._device_index.get(device)
     with ops.device(device):
       with distribute_lib.UpdateContext(device):
         with ops.name_scope(scope_name):
           device_args = values.select_device_mirrored(device, args)
           device_kwargs = values.select_device_mirrored(device, kwargs)
           per_device[device] = fn(*device_args, **device_kwargs)
   return values.update_regroup(self, per_device, grouped)
Ejemplo n.º 9
0
 def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs):
   """Applies `fn` on each device of `colocate_with` under an update context.

   `options` must hold only the "grouped" flag. The flag is popped from the
   caller's dict and forwarded to `values.update_regroup`.

   Args:
     colocate_with: list of devices; `fn` runs once per entry.
     options: dict containing exactly the key "grouped" (mutated).
     fn: callable receiving this device's view of `args` and `kwargs`.
     *args: positional arguments forwarded to `fn`.
     **kwargs: keyword arguments forwarded to `fn`.

   Returns:
     `values.update_regroup(self, <per-device results>, <grouped flag>)`.
   """
   assert isinstance(colocate_with, list)
   group_flag = options.pop("grouped")
   assert not options  # Validate that we are processing all of the options.
   # TODO(josh11b): In eager mode, use one thread per device.
   outputs = {}
   for target in colocate_with:
     label = "update_%d" % self._device_index.get(target)
     with ops.device(target), distribute_lib.UpdateContext(target), \
         ops.name_scope(label):
       outputs[target] = fn(
           *values.select_device_mirrored(target, args),
           **values.select_device_mirrored(target, kwargs))
   return values.update_regroup(self, outputs, group_flag)
Ejemplo n.º 10
0
 def _update(self, var, fn, args, kwargs, group):
     """Applies `fn` to each per-device component of a DistributedVariable.

     Args:
         var: a `values.DistributedVariable`. Its private `_index` provides
             the (device, component) pairs.
         fn: callable invoked as `fn(component, *device_args,
             **device_kwargs)`.
         args: positional-argument sequence, narrowed per device with
             `values.select_device_mirrored`.
         kwargs: keyword-argument mapping, narrowed the same way.
         group: forwarded unchanged to `values.update_regroup`.

     Returns:
         `values.update_regroup(self, <per-device results>, group)`.
     """
     # TODO(josh11b): In eager mode, use one thread per device.
     assert isinstance(var, values.DistributedVariable)
     results = {}
     for device, component in var._index.items():  # pylint: disable=protected-access
         scope = "update_%d" % self._device_index.get(device)
         with ops.device(device):
             with distribute_lib.UpdateContext(device):
                 with ops.name_scope(scope):
                     # If args and kwargs are not mirrored, the value is
                     # returned as is.
                     device_args = values.select_device_mirrored(device, args)
                     device_kwargs = values.select_device_mirrored(
                         device, kwargs)
                     results[device] = fn(component, *device_args,
                                          **device_kwargs)
     return values.update_regroup(self, results, group)
Ejemplo n.º 11
0
 def _update(self, var, fn, *args, **kwargs):
     """Applies `fn` to each per-device component of a MirroredVariable.

     Args:
         var: a `values.MirroredVariable` whose private `_index` maps devices
             to components.
         fn: callable invoked as `fn(component, *device_args,
             **device_kwargs)`.
         *args: positional arguments, narrowed per device with
             `values.select_device_mirrored`.
         **kwargs: keyword arguments, narrowed the same way.

     Returns:
         The per-device results regrouped with `values.Mirrored`.
     """
     # TODO(josh11b): Also support TowerLocalVariables here? If so, args and
     # kwargs don't need to be mirrored.
     assert isinstance(var, values.MirroredVariable)
     # TODO(josh11b): In eager mode, use one thread per device.
     outputs = {}
     for target, component in var._index.items():  # pylint: disable=protected-access
         label = "update_%d" % self._device_index.get(target)
         with ops.device(target):
             with distribute_lib.UpdateContext(target):
                 with ops.name_scope(label):
                     target_args = values.select_device_mirrored(target, args)
                     target_kwargs = values.select_device_mirrored(
                         target, kwargs)
                     outputs[target] = fn(component, *target_args,
                                          **target_kwargs)
     return values.regroup(outputs, values.Mirrored)
Ejemplo n.º 12
0
 def _update(self, var, fn, *args, **kwargs):
   """Updates every device-local component of a MirroredVariable with `fn`.

   Args:
     var: a `values.MirroredVariable`. The loop walks its private `_index` of
       (device, component) pairs.
     fn: callable called as `fn(component, *per_device_args,
       **per_device_kwargs)` inside that device's scopes.
     *args: positional arguments narrowed by `values.select_device_mirrored`.
     **kwargs: keyword arguments narrowed the same way.

   Returns:
     `values.regroup` of the results with `values.Mirrored`.
   """
   # TODO(josh11b): Also support TowerLocalVariables here? If so, args and
   # kwargs don't need to be mirrored.
   assert isinstance(var, values.MirroredVariable)
   # TODO(josh11b): In eager mode, use one thread per device.
   new_values = {}
   for dev, local_var in var._index.items():  # pylint: disable=protected-access
     scope_label = "update_%d" % self._device_index.get(dev)
     with ops.device(dev), distribute_lib.UpdateContext(dev), \
         ops.name_scope(scope_label):
       new_values[dev] = fn(local_var,
                            *values.select_device_mirrored(dev, args),
                            **values.select_device_mirrored(dev, kwargs))
   return values.regroup(new_values, values.Mirrored)
Ejemplo n.º 13
0
 def _update(self, var, options, fn, *args, **kwargs):
   """Applies `fn` to each component of `var` and regroups per `options`.

   Args:
     var: a `values.DistributedVariable`. Its private `_index` supplies the
       (device, component) pairs.
     options: dict that must contain exactly the key "grouped". The key is
       popped (mutating the caller's dict), and any leftover key fails the
       assertion.
     fn: callable invoked as `fn(component, *device_args, **device_kwargs)`.
     *args: positional arguments, narrowed per device with
       `values.select_device_mirrored`.
     **kwargs: keyword arguments, narrowed the same way.

   Returns:
     `values.update_regroup(self, <per-device results>, <grouped flag>)`.
   """
   # TODO(josh11b): In eager mode, use one thread per device.
   assert isinstance(var, values.DistributedVariable)
   grouped = options.pop("grouped")
   assert not options  # Validate that we are processing all of the options.
   results = {}
   for device, component in var._index.items():  # pylint: disable=protected-access
     scope = "update_%d" % self._device_index.get(device)
     with ops.device(device):
       with distribute_lib.UpdateContext(device):
         with ops.name_scope(scope):
           # If args and kwargs are not mirrored, the value is returned as is.
           device_args = values.select_device_mirrored(device, args)
           device_kwargs = values.select_device_mirrored(device, kwargs)
           results[device] = fn(component, *device_args, **device_kwargs)
   return values.update_regroup(self, results, grouped)
Ejemplo n.º 14
0
 def map(self, map_over, fn, *args, **kwargs):
   """Applies `fn` to each element of `map_over`, assigning devices round-robin.

   Element k runs under `ops.device(self._devices[k % len(self._devices)])`
   with this device's view of `args`/`kwargs` from
   `values.select_device_mirrored`.

   Returns:
     A `values.PerDevice` mapping each device that received at least one
     element to a `values.MapOutput` of that device's results, in input order.
   """
   # TODO(josh11b): In eager mode, use one thread per device.
   outputs_by_device = {}
   for position, item in enumerate(map_over):
     device = self._devices[position % len(self._devices)]
     with ops.device(device):
       bucket = outputs_by_device.setdefault(device, [])
       bucket.append(fn(item,
                        *values.select_device_mirrored(device, args),
                        **values.select_device_mirrored(device, kwargs)))
   # TODO(josh11b): Need a values.regroup equivalent that handles MapOutput
   # in addition to PerDevice data.
   return values.PerDevice(
       {device: values.MapOutput(results)
        for device, results in outputs_by_device.items()})
Ejemplo n.º 15
0
 def map(self, map_over, fn, *args, **kwargs):
   """Maps `fn` over `map_over`, spreading elements across `self._devices`.

   The k-th element is placed on device `k % len(self._devices)`. The
   arguments passed to `fn` are narrowed to that device with
   `values.select_device_mirrored`.

   Returns:
     `values.PerDevice` of {device: `values.MapOutput`(list of fn results)}.
   """
   # TODO(josh11b): In eager mode, use one thread per device.
   collected = {}
   for k, element in enumerate(map_over):
     dev = self._devices[k % len(self._devices)]
     with ops.device(dev):
       local_args = values.select_device_mirrored(dev, args)
       local_kwargs = values.select_device_mirrored(dev, kwargs)
       collected.setdefault(dev, []).append(
           fn(element, *local_args, **local_kwargs))
   # TODO(josh11b): Need a values.regroup equivalent that handles MapOutput
   # in addition to PerDevice data.
   wrapped = {dev: values.MapOutput(outs) for dev, outs in collected.items()}
   return values.PerDevice(wrapped)
Ejemplo n.º 16
0
 def map(self, map_over, fn, *args, **kwargs):
     """Applies `fn` to each element of `map_over`, assigning devices round-robin.

     Element k runs under `ops.device(self._devices[k % len(self._devices)])`
     and receives that device's view of `args`/`kwargs`, as produced by
     `values.select_device_mirrored`.

     Args:
         map_over: iterable of elements to process.
         fn: callable invoked as `fn(element, *device_args, **device_kwargs)`.
         *args: positional arguments forwarded (narrowed per device) to `fn`.
         **kwargs: keyword arguments forwarded (narrowed per device) to `fn`.

     Returns:
         A `values.PerDevice` mapping each device that received at least one
         element to a `values.MapOutput` of that device's results, in input
         order.
     """
     # TODO (josh11b): In eager mode, use one thread per device. id:1098
     # https://github.com/imdone/tensorflow/issues/1099
     index = {}
     # enumerate() advances the position for every element. The previous
     # version initialised a counter but never incremented it, so every
     # element was placed on self._devices[0].
     for i, m in enumerate(map_over):
         d = self._devices[i % len(self._devices)]
         with ops.device(d):
             index.setdefault(d, []).append(
                 fn(m, *values.select_device_mirrored(d, args),
                    **values.select_device_mirrored(d, kwargs)))
     # TODO (josh11b): Need a values.regroup equivalent that handles MapOutput id:1079
     # https://github.com/imdone/tensorflow/issues/1080
     # in addition to PerDevice data.
     return values.PerDevice(
         {k: values.MapOutput(v)
          for k, v in index.items()})
Ejemplo n.º 17
0
    def _update(self, var, fn, args, kwargs, group):
        """Updates a TPUMirroredVariable, either in-TPU or per device.

        When `values._enclosing_tpu_context()` is not None, `fn` is called
        once on `var` itself. Its result is returned directly if `group` is
        truthy, otherwise wrapped in a one-element list. Outside a TPU
        context, `fn` is applied to each per-device component (as in
        MirroredStrategy) and the results go through `values.update_regroup`.

        Args:
            var: a `values.TPUMirroredVariable`.
            fn: update callable; receives the variable (or a component)
                first.
            args: positional-argument sequence for `fn`.
            kwargs: keyword-argument mapping for `fn`.
            group: whether to return the grouped result.
        """
        assert isinstance(var, values.TPUMirroredVariable)
        if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
            if not group:
                return [fn(var, *args, **kwargs)]
            return fn(var, *args, **kwargs)

        # Otherwise, we revert to MirroredStrategy behavior and update each
        # variable directly.
        results = {}
        for device, component in var._index.items():  # pylint: disable=protected-access
            scope = "update_%d" % self._device_index.get(device)
            with ops.device(device):
                with distribute_lib.UpdateContext(device):
                    with ops.name_scope(scope):
                        # If args and kwargs are not mirrored, the value is
                        # returned as is.
                        device_args = values.select_device_mirrored(
                            device, args)
                        device_kwargs = values.select_device_mirrored(
                            device, kwargs)
                        results[device] = fn(component, *device_args,
                                             **device_kwargs)
        return values.update_regroup(self, results, group)
Ejemplo n.º 18
0
  def _update(self, var, options, fn, *args, **kwargs):
    """Updates a TPUMirroredVariable, either in-TPU or per device.

    Args:
      var: a `values.TPUMirroredVariable`.
      options: dict containing exactly the key "grouped". The key is popped
        (mutating the caller's dict), and any leftover key fails the
        assertion.
      fn: update callable; receives the variable (or a component) first.
      *args: positional arguments for `fn`.
      **kwargs: keyword arguments for `fn`.

    Returns:
      Inside an enclosing TPU context: `fn(var, ...)`, wrapped in a
      one-element list when "grouped" is falsy. Otherwise:
      `values.update_regroup` over the per-device results.
    """
    assert isinstance(var, values.TPUMirroredVariable)
    grouped = options.pop("grouped")
    assert not options  # Validate that we are processing all of the options.

    if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
      if not grouped:
        return [fn(var, *args, **kwargs)]
      return fn(var, *args, **kwargs)

    # Otherwise, we revert to MirroredStrategy behavior and update each variable
    # directly.
    results = {}
    for device, component in var._index.items():  # pylint: disable=protected-access
      scope = "update_%d" % self._device_index.get(device)
      with ops.device(device):
        with distribute_lib.UpdateContext(device):
          with ops.name_scope(scope):
            # If args and kwargs are not mirrored, the value is returned as is.
            device_args = values.select_device_mirrored(device, args)
            device_kwargs = values.select_device_mirrored(device, kwargs)
            results[device] = fn(component, *device_args, **device_kwargs)
    return values.update_regroup(self, results, grouped)
Ejemplo n.º 19
0
  def _update(self, var, fn, *args, **kwargs):
    """Updates a TPUMirroredVariable, tying per-device updates together.

    Inside an enclosing TPU context, `fn` is applied once to `var` and its
    result returned. Otherwise `fn` runs on each per-device component. The
    results are passed through `control_flow_ops.tuple` in sorted-device
    order, so that fetching one assignment runs all of them. They are then
    regrouped with `values.Mirrored`.

    Args:
      var: a `values.TPUMirroredVariable`.
      fn: update callable; receives the variable (or a component) first.
      *args: positional arguments, narrowed per device with
        `values.select_device_mirrored` on the per-device path.
      **kwargs: keyword arguments, narrowed the same way.
    """
    # TODO(jhseu): Consider supporting grouped==False.
    assert isinstance(var, values.TPUMirroredVariable)
    if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
      return fn(var, *args, **kwargs)

    # Otherwise, we revert to MirroredStrategy behavior and update each variable
    # directly.
    per_device = {}
    for device, component in var._index.items():  # pylint: disable=protected-access
      scope = "update_%d" % self._device_index.get(device)
      with ops.device(device):
        with distribute_lib.UpdateContext(device):
          with ops.name_scope(scope):
            # If args and kwargs are not mirrored, the value is returned as is.
            device_args = values.select_device_mirrored(device, args)
            device_kwargs = values.select_device_mirrored(device, kwargs)
            per_device[device] = fn(component, *device_args, **device_kwargs)

    # Make a single control dependency to keep the variables mirrored. If one
    # assignment is fetched, then run all assignments.
    ordered_devices = sorted(per_device.keys())
    tied = control_flow_ops.tuple(
        [per_device[device] for device in ordered_devices])
    for position, device in enumerate(ordered_devices):
      per_device[device] = tied[position]
    return values.regroup(per_device, values.Mirrored)