def testWrapClass(self):
  """Regrouping with `values.Mirrored` wraps every leaf as a Mirrored value.

  Walks the regrouped nested structure (tuple -> list -> dict) level by
  level, checking each leaf with `_is_per_device`. It then checks that both
  `select_device` and `select_device_mirrored` recover each device's original
  nested value.
  """
  # Normally a mirrored value would be the same across devices, but
  # for a test it is convenient to be able to tell the values apart.
  per_device_input = {
      _device_str(0): _nested_value("1"),
      _device_str(1): _nested_value("2")
  }
  result = values.regroup(per_device_input, values.Mirrored)
  mirrored = values.Mirrored

  # Outer level: a 3-tuple.
  self.assertIsInstance(result, tuple)
  self.assertEqual(3, len(result))
  self._is_per_device(result[0], ["a1", "a2"], mirrored)
  self._is_per_device(result[2], ["h1", "h2"], mirrored)

  # Middle level: the tuple's second entry is a 3-element list.
  middle = result[1]
  self.assertIsInstance(middle, list)
  self.assertEqual(3, len(middle))
  self._is_per_device(middle[0], ["b1", "b2"], mirrored)
  self._is_per_device(middle[2], ["g1", "g2"], mirrored)

  # Inner level: the list's second entry is a dict keyed by "c" and "e".
  inner = middle[1]
  self.assertIsInstance(inner, dict)
  self.assertEqual({"c", "e"}, set(inner.keys()))
  self._is_per_device(inner["c"], ["d1", "d2"], mirrored)
  self._is_per_device(inner["e"], ["f1", "f2"], mirrored)

  # Also test that we can undo the merge using select_device()
  for index, suffix in enumerate(("1", "2")):
    self.assertEqual(_nested_value(suffix),
                     values.select_device(_device_str(index), result))
  # Values are marked as mirrored, so select_device_mirrored() is allowed.
  for index, suffix in enumerate(("1", "2")):
    self.assertEqual(
        _nested_value(suffix),
        values.select_device_mirrored(_device_str(index), result))
def testNested(self):
  """Regrouping without a wrap class produces non-mirrored per-device leaves.

  `select_device` can still undo the merge. `select_device_mirrored` must
  raise TypeError for either device because the leaves are not mirrored.
  """
  per_device_input = {
      _device_str(0): _nested_value("1"),
      _device_str(1): _nested_value("2")
  }
  result = values.regroup(per_device_input)

  # Outer level: a 3-tuple.
  self.assertIsInstance(result, tuple)
  self.assertEqual(3, len(result))
  self._is_per_device(result[0], ["a1", "a2"])
  self._is_per_device(result[2], ["h1", "h2"])

  # Middle level: a 3-element list.
  middle = result[1]
  self.assertIsInstance(middle, list)
  self.assertEqual(3, len(middle))
  self._is_per_device(middle[0], ["b1", "b2"])
  self._is_per_device(middle[2], ["g1", "g2"])

  # Inner level: a dict keyed by "c" and "e".
  inner = middle[1]
  self.assertIsInstance(inner, dict)
  self.assertEqual({"c", "e"}, set(inner.keys()))
  self._is_per_device(inner["c"], ["d1", "d2"])
  self._is_per_device(inner["e"], ["f1", "f2"])

  # Also test that we can undo the merge using select_device()
  for index, suffix in enumerate(("1", "2")):
    self.assertEqual(_nested_value(suffix),
                     values.select_device(_device_str(index), result))

  # select_device_mirrored() should fail due to non-mirrored values
  for index in (0, 1):
    with self.assertRaises(TypeError):
      values.select_device_mirrored(_device_str(index), result)
def testWrapClass(self):
  """Checks that regroup(..., values.Mirrored) wraps each leaf as Mirrored.

  Also checks that the merge is reversible through both `select_device` and
  `select_device_mirrored`.
  """
  # Normally a mirrored value would be the same across devices, but
  # for a test it is convenient to be able to tell the values apart.
  result = values.regroup(
      {
          _device_str(0): _nested_value("1"),
          _device_str(1): _nested_value("2")
      }, values.Mirrored)
  wrap = values.Mirrored

  self.assertIsInstance(result, tuple)
  self.assertEqual(3, len(result))
  self._is_per_device(result[0], ["a1", "a2"], wrap)
  self._is_per_device(result[2], ["h1", "h2"], wrap)

  as_list = result[1]
  self.assertIsInstance(as_list, list)
  self.assertEqual(3, len(as_list))
  self._is_per_device(as_list[0], ["b1", "b2"], wrap)
  self._is_per_device(as_list[2], ["g1", "g2"], wrap)

  as_dict = as_list[1]
  self.assertIsInstance(as_dict, dict)
  self.assertEqual({"c", "e"}, set(as_dict.keys()))
  self._is_per_device(as_dict["c"], ["d1", "d2"], wrap)
  self._is_per_device(as_dict["e"], ["f1", "f2"], wrap)

  # Undo the merge with select_device() first. The values are marked as
  # mirrored, so select_device_mirrored() is allowed as well. Order: plain
  # selector for devices 0 and 1, then mirrored selector for devices 0 and 1.
  for selector in (values.select_device, values.select_device_mirrored):
    for index, suffix in enumerate(("1", "2")):
      self.assertEqual(_nested_value(suffix),
                       selector(_device_str(index), result))
def testNested(self):
  """Checks regroup() without a wrap class on a nested tuple/list/dict value.

  The merge must be reversible with `select_device`. Because the leaves are
  not mirrored, `select_device_mirrored` must raise TypeError.
  """
  result = values.regroup({
      _device_str(0): _nested_value("1"),
      _device_str(1): _nested_value("2")
  })

  self.assertIsInstance(result, tuple)
  self.assertEqual(3, len(result))
  self._is_per_device(result[0], ["a1", "a2"])
  self._is_per_device(result[2], ["h1", "h2"])

  as_list = result[1]
  self.assertIsInstance(as_list, list)
  self.assertEqual(3, len(as_list))
  self._is_per_device(as_list[0], ["b1", "b2"])
  self._is_per_device(as_list[2], ["g1", "g2"])

  as_dict = as_list[1]
  self.assertIsInstance(as_dict, dict)
  self.assertEqual({"c", "e"}, set(as_dict.keys()))
  self._is_per_device(as_dict["c"], ["d1", "d2"])
  self._is_per_device(as_dict["e"], ["f1", "f2"])

  # Also test that we can undo the merge using select_device()
  for index, suffix in enumerate(("1", "2")):
    device = _device_str(index)
    self.assertEqual(_nested_value(suffix), values.select_device(device, result))

  # select_device_mirrored() should fail due to non-mirrored values
  for index in range(2):
    with self.assertRaises(TypeError):
      values.select_device_mirrored(_device_str(index), result)
def _update_non_slot(self, colocate_with, fn, *args, **kwargs):
  """Calls `fn` once per device in `colocate_with`; regroups as Mirrored.

  Args:
    colocate_with: A list of devices (asserted to be a `list`).
    fn: Callable invoked once per device inside `ops.device(device)`,
      `distribute_lib.UpdateContext(device)` and an "update_%d" name scope.
      The "%d" is filled from `self._device_index.get(device)`, so a device
      missing from that index makes the "%d" formatting raise TypeError.
    *args: Positional arguments. The per-device view is taken with
      `values.select_device_mirrored(device, args)`.
    **kwargs: Keyword arguments, selected per device the same way.

  Returns:
    The per-device results combined by `values.regroup(..., values.Mirrored)`.
  """
  assert isinstance(colocate_with, list)
  # TODO(josh11b): In eager mode, use one thread per device.
  results_by_device = {}
  for device in colocate_with:
    scope_name = "update_%d" % self._device_index.get(device)
    with ops.device(device):
      with distribute_lib.UpdateContext(device):
        with ops.name_scope(scope_name):
          device_args = values.select_device_mirrored(device, args)
          device_kwargs = values.select_device_mirrored(device, kwargs)
          results_by_device[device] = fn(*device_args, **device_kwargs)
  return values.regroup(results_by_device, values.Mirrored)
def _update(self, var, fn, *args, **kwargs):
  """Applies `fn` to every per-device component of a DistributedVariable.

  Args:
    var: A `values.DistributedVariable`. Its `_index` (device -> component)
      drives the per-device loop.
    fn: Callable invoked as `fn(component, *device_args, **device_kwargs)`
      under the device, `UpdateContext` and "update_%d" name scope.
    *args: Positional arguments selected per device with
      `values.select_device_mirrored`.
    **kwargs: Keyword arguments selected per device the same way.

  Returns:
    The per-device results combined by `values.regroup(..., values.Mirrored)`.
  """
  # TODO(josh11b): In eager mode, use one thread per device.
  assert isinstance(var, values.DistributedVariable)
  results_by_device = {}
  for device, component in var._index.items():  # pylint: disable=protected-access
    scope_name = "update_%d" % self._device_index.get(device)
    with ops.device(device):
      with distribute_lib.UpdateContext(device):
        with ops.name_scope(scope_name):
          # If args and kwargs are not mirrored, the value is returned as is.
          device_args = values.select_device_mirrored(device, args)
          device_kwargs = values.select_device_mirrored(device, kwargs)
          results_by_device[device] = fn(component, *device_args,
                                         **device_kwargs)
  return values.regroup(results_by_device, values.Mirrored)
def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs):
  """Calls `fn` once per device in `colocate_with`, honoring `options`.

  Args:
    colocate_with: A list of devices (asserted to be a `list`).
    options: Dict that must contain exactly the key "grouped". The key is
      popped from the caller's dict (`KeyError` if missing), and any
      remaining key trips an assertion.
    fn: Callable invoked once per device under the device, `UpdateContext`
      and "update_%d" name scope.
    *args: Positional arguments selected per device with
      `values.select_device_mirrored`.
    **kwargs: Keyword arguments selected per device the same way.

  Returns:
    `values.update_regroup(self, results, should_group)`, where
    `should_group` is the popped "grouped" option.
  """
  assert isinstance(colocate_with, list)
  should_group = options.pop("grouped")
  assert not options  # Validate that we are processing all of the options.
  # TODO(josh11b): In eager mode, use one thread per device.
  results_by_device = {}
  for device in colocate_with:
    scope_name = "update_%d" % self._device_index.get(device)
    with ops.device(device):
      with distribute_lib.UpdateContext(device):
        with ops.name_scope(scope_name):
          device_args = values.select_device_mirrored(device, args)
          device_kwargs = values.select_device_mirrored(device, kwargs)
          results_by_device[device] = fn(*device_args, **device_kwargs)
  return values.update_regroup(self, results_by_device, should_group)
def _update(self, var, fn, args, kwargs, group):
  """Applies `fn` to each per-device component of a DistributedVariable.

  Args:
    var: A `values.DistributedVariable`. Its `_index` maps device ->
      component.
    fn: Callable invoked as `fn(component, *device_args, **device_kwargs)`.
    args: Sequence of positional arguments, selected per device with
      `values.select_device_mirrored`.
    kwargs: Mapping of keyword arguments, selected per device the same way.
    group: Passed straight through to `values.update_regroup`.

  Returns:
    `values.update_regroup(self, results, group)`.
  """
  # TODO(josh11b): In eager mode, use one thread per device.
  assert isinstance(var, values.DistributedVariable)
  results_by_device = {}
  for device, component in var._index.items():  # pylint: disable=protected-access
    scope_name = "update_%d" % self._device_index.get(device)
    with ops.device(device):
      with distribute_lib.UpdateContext(device):
        with ops.name_scope(scope_name):
          # If args and kwargs are not mirrored, the value is returned as is.
          device_args = values.select_device_mirrored(device, args)
          device_kwargs = values.select_device_mirrored(device, kwargs)
          results_by_device[device] = fn(component, *device_args,
                                         **device_kwargs)
  return values.update_regroup(self, results_by_device, group)
def _update(self, var, fn, *args, **kwargs):
  """Applies `fn` to each per-device component of a MirroredVariable.

  Args:
    var: A `values.MirroredVariable` (asserted). Its `_index` maps device ->
      component.
    fn: Callable invoked as `fn(component, *device_args, **device_kwargs)`
      under the device, `UpdateContext` and "update_%d" name scope.
    *args: Positional arguments selected per device with
      `values.select_device_mirrored`.
    **kwargs: Keyword arguments selected per device the same way.

  Returns:
    The per-device results combined by `values.regroup(..., values.Mirrored)`.
  """
  # TODO(josh11b): Also support TowerLocalVariables here? If so, args and
  # kwargs don't need to be mirrored.
  assert isinstance(var, values.MirroredVariable)
  # TODO(josh11b): In eager mode, use one thread per device.
  results_by_device = {}
  for device, component in var._index.items():  # pylint: disable=protected-access
    scope_name = "update_%d" % self._device_index.get(device)
    with ops.device(device):
      with distribute_lib.UpdateContext(device):
        with ops.name_scope(scope_name):
          device_args = values.select_device_mirrored(device, args)
          device_kwargs = values.select_device_mirrored(device, kwargs)
          results_by_device[device] = fn(component, *device_args,
                                         **device_kwargs)
  return values.regroup(results_by_device, values.Mirrored)
def _update(self, var, fn, *args, **kwargs):
  """Runs `fn` on every device-local copy of a MirroredVariable.

  Each copy in `var._index` is updated under its own device,
  `distribute_lib.UpdateContext` and "update_%d" name scope. `args`/`kwargs`
  are narrowed to that device with `values.select_device_mirrored`. The
  results are combined with `values.regroup(..., values.Mirrored)`.
  """
  # TODO(josh11b): Also support TowerLocalVariables here? If so, args and
  # kwargs don't need to be mirrored.
  assert isinstance(var, values.MirroredVariable)
  # TODO(josh11b): In eager mode, use one thread per device.
  per_device = {}
  for dev, local_copy in var._index.items():  # pylint: disable=protected-access
    scope = "update_%d" % self._device_index.get(dev)
    with ops.device(dev), distribute_lib.UpdateContext(dev), \
        ops.name_scope(scope):
      local_args = values.select_device_mirrored(dev, args)
      local_kwargs = values.select_device_mirrored(dev, kwargs)
      per_device[dev] = fn(local_copy, *local_args, **local_kwargs)
  return values.regroup(per_device, values.Mirrored)
def _update(self, var, options, fn, *args, **kwargs):
  """Applies `fn` to each per-device component of a DistributedVariable.

  Args:
    var: A `values.DistributedVariable` (asserted). Its `_index` maps
      device -> component.
    options: Dict that must contain exactly the key "grouped". It is popped
      from the caller's dict, and any leftover key trips an assertion.
    fn: Callable invoked as `fn(component, *device_args, **device_kwargs)`.
    *args: Positional arguments selected per device with
      `values.select_device_mirrored`.
    **kwargs: Keyword arguments selected per device the same way.

  Returns:
    `values.update_regroup(self, results, should_group)`.
  """
  # TODO(josh11b): In eager mode, use one thread per device.
  assert isinstance(var, values.DistributedVariable)
  should_group = options.pop("grouped")
  assert not options  # Validate that we are processing all of the options.
  results_by_device = {}
  for device, component in var._index.items():  # pylint: disable=protected-access
    scope_name = "update_%d" % self._device_index.get(device)
    with ops.device(device):
      with distribute_lib.UpdateContext(device):
        with ops.name_scope(scope_name):
          # If args and kwargs are not mirrored, the value is returned as is.
          device_args = values.select_device_mirrored(device, args)
          device_kwargs = values.select_device_mirrored(device, kwargs)
          results_by_device[device] = fn(component, *device_args,
                                         **device_kwargs)
  return values.update_regroup(self, results_by_device, should_group)
def map(self, map_over, fn, *args, **kwargs):
  """Applies `fn` to each element of `map_over`, round-robin over devices.

  Element number `i` is processed under
  `ops.device(self._devices[i % len(self._devices)])`.

  Returns:
    A `values.PerDevice` mapping each device that received at least one
    element to a `values.MapOutput` holding that device's results in input
    order.
  """
  # TODO(josh11b): In eager mode, use one thread per device.
  outputs_by_device = {}
  for position, element in enumerate(map_over):
    device = self._devices[position % len(self._devices)]
    with ops.device(device):
      device_args = values.select_device_mirrored(device, args)
      device_kwargs = values.select_device_mirrored(device, kwargs)
      outputs_by_device.setdefault(device, []).append(
          fn(element, *device_args, **device_kwargs))
  # TODO(josh11b): Need a values.regroup equivalent that handles MapOutput
  # in addition to PerDevice data.
  return values.PerDevice({
      device: values.MapOutput(outputs)
      for device, outputs in outputs_by_device.items()
  })
def map(self, map_over, fn, *args, **kwargs):
  """Applies `fn` to each element of `map_over`, round-robin over devices.

  Element number `i` of `map_over` is processed under
  `ops.device(self._devices[i % len(self._devices)])`. `args`/`kwargs` are
  narrowed to that device with `values.select_device_mirrored`.

  Args:
    map_over: Iterable of elements to process.
    fn: Callable invoked as `fn(element, *device_args, **device_kwargs)`.
    *args: Positional arguments selected per device.
    **kwargs: Keyword arguments selected per device.

  Returns:
    A `values.PerDevice` mapping each device that received at least one
    element to a `values.MapOutput` holding that device's results in input
    order.
  """
  # TODO(josh11b): In eager mode, use one thread per device.
  index = {}
  # enumerate() supplies each element's position so elements are spread
  # round-robin across devices. A counter that is never advanced would pin
  # every element to self._devices[0].
  for i, m in enumerate(map_over):
    d = self._devices[i % len(self._devices)]
    with ops.device(d):
      l = index.get(d, [])
      l.append(
          fn(m, *values.select_device_mirrored(d, args),
             **values.select_device_mirrored(d, kwargs)))
      index[d] = l
  # TODO(josh11b): Need a values.regroup equivalent that handles MapOutput
  # in addition to PerDevice data.
  return values.PerDevice(
      {k: values.MapOutput(v) for k, v in index.items()})
def _update(self, var, fn, args, kwargs, group):
  """Applies `fn` to a TPUMirroredVariable, inside or outside a TPU context.

  When `values._enclosing_tpu_context()` is not None, `fn` is applied once to
  `var` itself. The result is returned directly if `group` is truthy, else
  wrapped in a one-element list. Otherwise each per-device component in
  `var._index` is updated under its device, `UpdateContext` and "update_%d"
  name scope, and the results go through `values.update_regroup`.

  Args:
    var: A `values.TPUMirroredVariable` (asserted).
    fn: Callable taking the variable (or component) followed by the args.
    args: Sequence of positional arguments.
    kwargs: Mapping of keyword arguments.
    group: Controls grouping of the returned value(s).
  """
  assert isinstance(var, values.TPUMirroredVariable)
  if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
    tpu_result = fn(var, *args, **kwargs)
    return tpu_result if group else [tpu_result]

  # Otherwise, we revert to MirroredStrategy behavior and update each variable
  # directly.
  results_by_device = {}
  for device, component in var._index.items():  # pylint: disable=protected-access
    scope_name = "update_%d" % self._device_index.get(device)
    with ops.device(device):
      with distribute_lib.UpdateContext(device):
        with ops.name_scope(scope_name):
          # If args and kwargs are not mirrored, the value is returned as is.
          device_args = values.select_device_mirrored(device, args)
          device_kwargs = values.select_device_mirrored(device, kwargs)
          results_by_device[device] = fn(component, *device_args,
                                         **device_kwargs)
  return values.update_regroup(self, results_by_device, group)
def _update(self, var, options, fn, *args, **kwargs):
  """Applies `fn` to a TPUMirroredVariable, honoring the "grouped" option.

  `options` must contain exactly the key "grouped". It is popped from the
  caller's dict, and any leftover key trips an assertion.

  Inside a TPU context (`values._enclosing_tpu_context()` is not None), `fn`
  is applied once to `var`. The result is returned directly when grouped,
  else wrapped in a one-element list. Outside one, each per-device component
  is updated separately and the results go through `values.update_regroup`.
  """
  assert isinstance(var, values.TPUMirroredVariable)
  should_group = options.pop("grouped")
  assert not options  # Validate that we are processing all of the options.

  if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
    tpu_result = fn(var, *args, **kwargs)
    return tpu_result if should_group else [tpu_result]

  # Otherwise, we revert to MirroredStrategy behavior and update each variable
  # directly.
  results_by_device = {}
  for device, component in var._index.items():  # pylint: disable=protected-access
    scope_name = "update_%d" % self._device_index.get(device)
    with ops.device(device):
      with distribute_lib.UpdateContext(device):
        with ops.name_scope(scope_name):
          # If args and kwargs are not mirrored, the value is returned as is.
          device_args = values.select_device_mirrored(device, args)
          device_kwargs = values.select_device_mirrored(device, kwargs)
          results_by_device[device] = fn(component, *device_args,
                                         **device_kwargs)
  return values.update_regroup(self, results_by_device, should_group)
def _update(self, var, fn, *args, **kwargs):
  """Applies `fn` to a TPUMirroredVariable, keeping its copies tied together.

  Inside a TPU context (`values._enclosing_tpu_context()` is not None), `fn`
  is applied once to `var` and its result is returned unchanged. Otherwise
  each per-device component is updated separately. The results are then
  passed together through `control_flow_ops.tuple` (in sorted device order)
  and regrouped with `values.Mirrored`.
  """
  # TODO(jhseu): Consider supporting grouped==False.
  assert isinstance(var, values.TPUMirroredVariable)
  if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
    return fn(var, *args, **kwargs)

  # Otherwise, we revert to MirroredStrategy behavior and update each variable
  # directly.
  results_by_device = {}
  for device, component in var._index.items():  # pylint: disable=protected-access
    scope_name = "update_%d" % self._device_index.get(device)
    with ops.device(device):
      with distribute_lib.UpdateContext(device):
        with ops.name_scope(scope_name):
          # If args and kwargs are not mirrored, the value is returned as is.
          device_args = values.select_device_mirrored(device, args)
          device_kwargs = values.select_device_mirrored(device, kwargs)
          results_by_device[device] = fn(component, *device_args,
                                         **device_kwargs)

  # Make a single control dependency to keep the variables mirrored. If one
  # assignment is fetched, then run all assignments.
  ordered_devices = sorted(results_by_device.keys())
  tied = control_flow_ops.tuple(
      [results_by_device[device] for device in ordered_devices])
  for position, device in enumerate(ordered_devices):
    results_by_device[device] = tied[position]
  return values.regroup(results_by_device, values.Mirrored)