def _update(self, var, fn, args, kwargs, group):
  """Applies `fn` to `var`, once inside a TPU context or once per device.

  Args:
    var: Variable to update. Asserted to be a `tpu_values.TPUVariableMixin`
      or a `resource_variable_ops.BaseResourceVariable`.
    fn: Callable invoked with a variable (or one of its components) followed
      by `args` and `kwargs`.
    args: Positional arguments forwarded to `fn`.
    kwargs: Keyword arguments forwarded to `fn`.
    group: Inside a TPU context, a truthy value returns the result of `fn`
      as is and a falsy value wraps it in a 1-tuple. Outside a TPU context
      it is handed to `distribute_utils.update_regroup`.

  Returns:
    Inside a TPU context, the result of `fn` (or a 1-tuple containing it).
    Otherwise, the value produced by `distribute_utils.update_regroup` for
    the list of per-device results.
  """
  assert isinstance(var, (tpu_values.TPUVariableMixin,
                          resource_variable_ops.BaseResourceVariable))

  if tpu_values.enclosing_tpu_context() is not None:
    # Inside a TPU context `fn` is applied to the variable exactly once.
    result = fn(var, *args, **kwargs)
    return result if group else (result,)

  # Outside a TPU context, fall back to MirroredStrategy behavior and update
  # the variable on every replica directly.
  packed = var._packed_variable  # pylint: disable=protected-access
  if packed is None:
    # No packed variable: pair every component with its own device.
    targets = [(component, component.device) for component in var.values]
  else:
    # Packed variable: the same object is paired with each of its devices.
    targets = [(packed, device) for device in packed.devices]

  per_device_results = []
  for replica_id, (target, device) in enumerate(targets):
    scope_name = "update_%d" % replica_id
    with ops.device(device):
      with distribute_lib.UpdateContext(replica_id):
        with ops.name_scope(scope_name):
          # Args and kwargs that are not mirrored are passed through as is.
          replica_args = distribute_utils.select_replica_mirrored(
              replica_id, args)
          replica_kwargs = distribute_utils.select_replica_mirrored(
              replica_id, kwargs)
          per_device_results.append(
              fn(target, *replica_args, **replica_kwargs))
  return distribute_utils.update_regroup(self, per_device_results, group)
def testWrapClass(self):
  """Checks `regroup` on nested inputs when wrapping with `values.Mirrored`.

  Verifies the regrouped structure level by level (tuple -> list -> dict),
  then checks that both `select_replica` and `select_replica_mirrored`
  recover the original per-replica inputs.
  """
  # Normally a mirrored value would be the same across devices, but
  # for a test it is convenient to be able to tell the values apart.
  per_replica_inputs = (_nested_value("1"), _nested_value("2"))
  result = distribute_utils.regroup(per_replica_inputs, values.Mirrored)

  # Outer level: a 3-tuple whose first and last entries are wrapped leaves.
  self.assertIsInstance(result, tuple)
  self.assertLen(result, 3)
  self._is_per_replica(result[0], ["a1", "a2"], values.Mirrored)
  self._is_per_replica(result[2], ["h1", "h2"], values.Mirrored)

  # Middle level: a 3-element list whose first and last entries are leaves.
  inner_list = result[1]
  self.assertIsInstance(inner_list, list)
  self.assertLen(inner_list, 3)
  self._is_per_replica(inner_list[0], ["b1", "b2"], values.Mirrored)
  self._is_per_replica(inner_list[2], ["g1", "g2"], values.Mirrored)

  # Innermost level: a dict with exactly the keys "c" and "e".
  inner_dict = inner_list[1]
  self.assertIsInstance(inner_dict, dict)
  self.assertEqual({"c", "e"}, set(inner_dict))
  self._is_per_replica(inner_dict["c"], ["d1", "d2"], values.Mirrored)
  self._is_per_replica(inner_dict["e"], ["f1", "f2"], values.Mirrored)

  # The merge can be undone with select_replica().
  for replica_id, suffix in enumerate(("1", "2")):
    self.assertEqual(_nested_value(suffix),
                     distribute_utils.select_replica(replica_id, result))

  # The values are marked as mirrored, so select_replica_mirrored() is
  # allowed as well.
  for replica_id, suffix in enumerate(("1", "2")):
    self.assertEqual(
        _nested_value(suffix),
        distribute_utils.select_replica_mirrored(replica_id, result))
def testNested(self):
  """Checks `regroup` on nested inputs with the default wrapper.

  Verifies the regrouped structure level by level (tuple -> list -> dict),
  that `select_replica` recovers the original per-replica inputs, and that
  `select_replica_mirrored` raises `TypeError` for the result.
  """
  per_replica_inputs = (_nested_value("1"), _nested_value("2"))
  result = distribute_utils.regroup(per_replica_inputs)

  # Outer level: a 3-tuple whose first and last entries are wrapped leaves.
  self.assertIsInstance(result, tuple)
  self.assertLen(result, 3)
  self._is_per_replica(result[0], ["a1", "a2"])
  self._is_per_replica(result[2], ["h1", "h2"])

  # Middle level: a 3-element list whose first and last entries are leaves.
  inner_list = result[1]
  self.assertIsInstance(inner_list, list)
  self.assertLen(inner_list, 3)
  self._is_per_replica(inner_list[0], ["b1", "b2"])
  self._is_per_replica(inner_list[2], ["g1", "g2"])

  # Innermost level: a dict with exactly the keys "c" and "e".
  inner_dict = inner_list[1]
  self.assertIsInstance(inner_dict, dict)
  self.assertEqual({"c", "e"}, set(inner_dict))
  self._is_per_replica(inner_dict["c"], ["d1", "d2"])
  self._is_per_replica(inner_dict["e"], ["f1", "f2"])

  # The merge can be undone with select_replica().
  for replica_id, suffix in enumerate(("1", "2")):
    self.assertEqual(_nested_value(suffix),
                     distribute_utils.select_replica(replica_id, result))

  # select_replica_mirrored() should fail because the values are not
  # mirrored.
  for replica_id in (0, 1):
    with self.assertRaises(TypeError):
      distribute_utils.select_replica_mirrored(replica_id, result)
def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
  """Calls `fn` once for every entry of `colocate_with`.

  Args:
    colocate_with: Tuple (asserted) whose entries are each passed to
      `ops.device` for one call of `fn`.
    fn: Callable invoked with the per-replica selection of `args` and
      `kwargs`.
    args: Positional arguments; the selection for the current index is
      forwarded to `fn`.
    kwargs: Keyword arguments; the selection for the current index is
      forwarded to `fn`.
    group: Passed through to `distribute_utils.update_regroup`.

  Returns:
    The value produced by `distribute_utils.update_regroup` for the list of
    per-device results.
  """
  assert isinstance(colocate_with, tuple)
  # TODO(josh11b): In eager mode, use one thread per device.
  per_device_results = []
  for replica_id, device in enumerate(colocate_with):
    scope_name = "update_%d" % replica_id
    with ops.device(device):
      with distribute_lib.UpdateContext(replica_id):
        with ops.name_scope(scope_name):
          replica_args = distribute_utils.select_replica_mirrored(
              replica_id, args)
          replica_kwargs = distribute_utils.select_replica_mirrored(
              replica_id, kwargs)
          per_device_results.append(fn(*replica_args, **replica_kwargs))
  return distribute_utils.update_regroup(self, per_device_results, group)
def _update(self, var, fn, args, kwargs, group):
  """Applies `fn` to every component of the distributed variable `var`.

  Args:
    var: Variable to update. Asserted to be a `values.DistributedVariable`.
    fn: Callable invoked with one component of `var` followed by the
      per-replica selection of `args` and `kwargs`.
    args: Positional arguments; the selection for the current index is
      forwarded to `fn`.
    kwargs: Keyword arguments; the selection for the current index is
      forwarded to `fn`.
    group: Passed through to `distribute_utils.update_regroup`.

  Returns:
    The value produced by `distribute_utils.update_regroup` for the list of
    per-component results.
  """
  # TODO(josh11b): In eager mode, use one thread per device.
  assert isinstance(var, values.DistributedVariable)
  per_device_results = []
  for replica_id, component in enumerate(var.values):
    scope_name = "update_%d" % replica_id
    with ops.device(component.device):
      with distribute_lib.UpdateContext(replica_id):
        with ops.name_scope(scope_name):
          # Args and kwargs that are not mirrored are passed through as is.
          replica_args = distribute_utils.select_replica_mirrored(
              replica_id, args)
          replica_kwargs = distribute_utils.select_replica_mirrored(
              replica_id, kwargs)
          per_device_results.append(
              fn(component, *replica_args, **replica_kwargs))
  return distribute_utils.update_regroup(self, per_device_results, group)
def _update(self, var, fn, args, kwargs, group):
  """Applies `fn` to `var`, once inside a TPU context or once per component.

  Args:
    var: Variable to update. Asserted to be a `tpu_values.TPUVariableMixin`
      or a `resource_variable_ops.BaseResourceVariable`.
    fn: Callable invoked with a variable (or one of its components) followed
      by `args` and `kwargs`.
    args: Positional arguments forwarded to `fn`.
    kwargs: Keyword arguments forwarded to `fn`.
    group: Inside a TPU context, a truthy value returns the result of `fn`
      as is and a falsy value wraps it in a 1-tuple. Outside a TPU context
      it is handed to `distribute_utils.update_regroup`.

  Returns:
    Inside a TPU context, the result of `fn` (or a 1-tuple containing it).
    Otherwise, the value produced by `distribute_utils.update_regroup` for
    the list of per-component results.
  """
  assert isinstance(var, (tpu_values.TPUVariableMixin,
                          resource_variable_ops.BaseResourceVariable))

  if tpu_values.enclosing_tpu_context() is not None:
    # Inside a TPU context `fn` is applied to the variable exactly once.
    result = fn(var, *args, **kwargs)
    return result if group else (result,)

  # Outside a TPU context, fall back to MirroredStrategy behavior and update
  # each component variable directly.
  per_device_results = []
  for replica_id, component in enumerate(var.values):
    scope_name = "update_%d" % replica_id
    with ops.device(component.device):
      with distribute_lib.UpdateContext(replica_id):
        with ops.name_scope(scope_name):
          # Args and kwargs that are not mirrored are passed through as is.
          replica_args = distribute_utils.select_replica_mirrored(
              replica_id, args)
          replica_kwargs = distribute_utils.select_replica_mirrored(
              replica_id, kwargs)
          per_device_results.append(
              fn(component, *replica_args, **replica_kwargs))
  return distribute_utils.update_regroup(self, per_device_results, group)