def __init__(self, strategy, values, aggregation):
  """Builds a distributed variable from its per-component variables.

  Args:
    strategy: The distribution strategy this variable belongs to. Its
      optional `_enable_packed_variable_in_eager_mode` attribute (treated as
      False when absent) controls whether a packed variable is created.
    values: The component variables. They are forwarded to the parent
      constructor and, when packing is enabled, also to
      `packed.PackedDistributedVariable`.
    aggregation: The aggregation setting, stored on `self._aggregation`.
  """
  # Both of these are assigned before the parent constructor runs; the
  # original ordering is kept deliberately.
  self._distribute_strategy = strategy
  self._aggregation = aggregation
  super(DistributedVariable, self).__init__(values)

  # Everything before the first ":" of the primary component's name.
  # NOTE(review): `self._primary` is presumably populated by the parent
  # constructor above — confirm.
  self._common_name = self._primary.name.partition(":")[0]

  # A packed variable lets a function graph capture a single variable handle
  # for the whole DistributedVariable, which reduces function-execution
  # overhead. It is only built when executing eagerly outside functions and
  # the strategy opts in.
  wants_packed = ops.executing_eagerly_outside_functions() and getattr(
      strategy, "_enable_packed_variable_in_eager_mode", False)
  if not wants_packed:
    self._packed_var = None
  else:
    packed_name = "%s/packed/" % self._common_name
    self._packed_var = packed.PackedDistributedVariable(
        values, name=packed_name)

  # tf.keras tracks initialization through this attribute: when it obtains
  # the default session it initializes every uninitialized variable. It must
  # be a real member here; otherwise the lookup would fall through to
  # `__getattr__` and be delegated to a component variable.
  self._keras_initialized = False

  # Usually the initializer is composed from the component initializers. In
  # some cases, such as restoring from a checkpoint, `_initializer_op` may
  # instead be set for the whole DistributedVariable. It starts unset.
  self._initializer_op = None
def testPackedVariable(self):
  """Checks eager and tf.function updates routed through a packed variable.

  Two component variables live on '/cpu:0' (initial 1.0) and '/cpu:1'
  (initial 2.0). Under each device scope, the packed variable must resolve to
  that device's component and apply updates there.
  """
  with ops.device('/cpu:0'):
    v0 = resource_variable_ops.ResourceVariable(1.0, name='var0')
  with ops.device('/cpu:1'):
    v1 = resource_variable_ops.ResourceVariable(2.0, name='var1')

  packed_var = packed_distributed_variable.PackedDistributedVariable(
      [v0, v1])
  self.assertTrue(packed_var.handle.is_packed)
  self.assertTrue(packed_var.is_initialized)

  # Eager updates: on each device, assign(2.0) then assign_add(1.0) -> 3.0.
  with ops.device('/cpu:0'):
    self.assertAllEqual(packed_var.get_var_on_current_device(), v0)
    val0 = packed_var.assign(2.0).assign_add(1.0)
    self.assertAllEqual(val0, 3.0)

  with ops.device('/cpu:1'):
    self.assertAllEqual(packed_var.get_var_on_current_device(), v1)
    # Named for device 1; the original reused the copy-pasted `val0` here.
    val1 = packed_var.assign(2.0).assign_add(1.0)
    self.assertAllEqual(val1, 3.0)

  @def_function.function
  def update_var():
    # Both components now hold 3.0; each device scope updates its own one.
    with ops.device('/cpu:0'):
      packed_var.assign_add(3.0).assign_sub(1.0)  # 3.0 + 3.0 - 1.0 = 5.0
      read0 = packed_var.value()
    with ops.device('/cpu:1'):
      packed_var.assign_sub(4.0).assign_sub(2.0)  # 3.0 - 4.0 - 2.0 = -3.0
      read1 = packed_var.value()
    return read0, read1

  self.assertAllEqual(update_var(), (5.0, -3.0))
def testPackedVarAndDevice(self):
  """Exercises PackedVarAndDevice as a device-bound view of a packed variable.

  A view is created for each of two canonicalized CPU devices. Reads, math ops
  and assignments through each view are checked against the component on
  that device, first eagerly and then inside a tf.function.
  """
  cpu0, cpu1 = (
      device_util.canonicalize(spec) for spec in ('/cpu:0', '/cpu:1'))

  with ops.device(cpu0):
    first = resource_variable_ops.ResourceVariable(1.0)
  with ops.device(cpu1):
    second = resource_variable_ops.ResourceVariable(2.0)
  combined = packed_distributed_variable.PackedDistributedVariable(
      [first, second])

  # Eager: the cpu0 view reads 1.0, so doubling it gives 2.0. The cpu1 view
  # then accepts an assignment of 3.0.
  view0 = packed_distributed_variable.PackedVarAndDevice(combined, cpu0)
  self.assertTrue(view0.handle.is_packed)
  self.assertAllEqual(math_ops.mul(view0, 2.0), 2.0)
  view1 = packed_distributed_variable.PackedVarAndDevice(combined, cpu1)
  self.assertAllEqual(view1.assign(3.0), 3.0)

  @def_function.function
  def bump_and_read():
    # Fresh views are built inside the function. cpu0 goes 1.0 -> 4.0, and
    # the cpu1 component (now 3.0) is read through math_ops.add (+2.0).
    inner0 = packed_distributed_variable.PackedVarAndDevice(combined, cpu0)
    inner0.assign_add(3.0)
    inner1 = packed_distributed_variable.PackedVarAndDevice(combined, cpu1)
    return inner0.value(), math_ops.add(inner1, 2.0)

  self.assertAllEqual(bump_and_read(), (4.0, 5.0))
def testNoGarbage(self):
  """Looks up a missing attribute on a device-bound view of a packed variable.

  The body asserts nothing itself; it only performs the lookup.
  NOTE(review): given the test name, this presumably runs under a
  no-garbage / reference-cycle checker applied outside this view — confirm.
  """
  device0 = device_util.canonicalize('/cpu:0')
  device1 = device_util.canonicalize('/cpu:1')
  with ops.device(device0):
    v0 = resource_variable_ops.ResourceVariable(1.0)
  with ops.device(device1):
    v1 = resource_variable_ops.ResourceVariable(2.0)
  packed_var = packed_distributed_variable.PackedDistributedVariable([v0, v1])
  # 'nonexist' is not a real attribute, so hasattr() exercises the failed
  # attribute-lookup path of the object returned by on_device(). That path
  # needs a workaround in the implementation to avoid creating reference
  # cycles. The hasattr() result is intentionally discarded.
  hasattr(packed_var.on_device('/cpu:0'), 'nonexist')