Beispiel #1
0
    def __init__(self, strategy, values, aggregation):
        """Builds the distributed variable from its component variables.

        Args:
          strategy: The distribution strategy this variable belongs to. It is
            also consulted for the `_enable_packed_variable_in_eager_mode`
            attribute (treated as False when absent).
          values: The component variables. They are forwarded to the parent
            constructor and, when packing is enabled, to
            `PackedDistributedVariable`.
          aggregation: The aggregation setting, stored as `_aggregation`.
        """
        self._distribute_strategy = strategy
        self._aggregation = aggregation
        super(DistributedVariable, self).__init__(values)
        # Everything before the first ":" in the primary component's name.
        self._common_name = self._primary.name.split(":")[0]

        # A packed variable lets a function graph capture a single handle for
        # the whole DistributedVariable, which reduces function-execution
        # overhead. It is only built in eager mode, and only when the strategy
        # opts in.
        wants_packing = ops.executing_eagerly_outside_functions() and getattr(
            strategy, "_enable_packed_variable_in_eager_mode", False)
        self._packed_var = (
            packed.PackedDistributedVariable(
                values, name="%s/packed/" % self._common_name)
            if wants_packing else None)

        # tf.keras tracks which variables it has initialized via
        # `_keras_initialized`; when it fetches the default session it
        # initializes every uninitialized variable. Defining the attribute here
        # as a real member stops the lookup from falling through to
        # `__getattr__`, which would delegate it to a component variable.
        self._keras_initialized = False
        # The initializer is normally assembled from the component variables'
        # initializers. Some paths, such as restoring from a checkpoint, may
        # instead set `_initializer_op` for the whole DistributedVariable.
        self._initializer_op = None
Beispiel #2
0
    def testPackedVariable(self):
        """Reads and writes through a packed variable, eagerly and in a tf.function."""
        with ops.device('/cpu:0'):
            v0 = resource_variable_ops.ResourceVariable(1.0, name='var0')
        with ops.device('/cpu:1'):
            v1 = resource_variable_ops.ResourceVariable(2.0, name='var1')

        packed_var = packed_distributed_variable.PackedDistributedVariable(
            [v0, v1])
        self.assertTrue(packed_var.handle.is_packed)
        self.assertTrue(packed_var.is_initialized)

        # Eagerly, each device scope must resolve to its own component, and a
        # chained assign/assign_add must leave that component at 3.0.
        for device, component in (('/cpu:0', v0), ('/cpu:1', v1)):
            with ops.device(device):
                self.assertAllEqual(packed_var.get_var_on_current_device(),
                                    component)
                updated = packed_var.assign(2.0).assign_add(1.0)
                self.assertAllEqual(updated, 3.0)

        @def_function.function
        def update_var():
            # Both components hold 3.0 on entry:
            # cpu:0 -> 3 + 3 - 1 = 5, cpu:1 -> 3 - 4 - 2 = -3.
            with ops.device('/cpu:0'):
                packed_var.assign_add(3.0).assign_sub(1.0)
                first = packed_var.value()
            with ops.device('/cpu:1'):
                packed_var.assign_sub(4.0).assign_sub(2.0)
                second = packed_var.value()
            return first, second

        self.assertAllEqual(update_var(), (5.0, -3.0))
Beispiel #3
0
    def testPackedVarAndDevice(self):
        """Uses PackedVarAndDevice to address one component of a packed variable."""
        device0 = device_util.canonicalize('/cpu:0')
        device1 = device_util.canonicalize('/cpu:1')

        with ops.device(device0):
            v0 = resource_variable_ops.ResourceVariable(1.0)
        with ops.device(device1):
            v1 = resource_variable_ops.ResourceVariable(2.0)

        packed_var = packed_distributed_variable.PackedDistributedVariable(
            [v0, v1])

        def pinned(device):
            # Pairs the shared packed variable with `device`.
            return packed_distributed_variable.PackedVarAndDevice(
                packed_var, device)

        # Eager: reading through device0 sees v0 (1.0 * 2 == 2.0); writing
        # through device1 sets v1 to 3.0.
        on_device0 = pinned(device0)
        self.assertTrue(on_device0.handle.is_packed)
        self.assertAllEqual(math_ops.mul(on_device0, 2.0), 2.0)

        on_device1 = pinned(device1)
        self.assertAllEqual(on_device1.assign(3.0), 3.0)

        @def_function.function
        def func():
            # v0: 1 + 3 = 4; v1 (now 3.0) + 2 = 5.
            first = pinned(device0)
            first.assign_add(3.0)
            second = pinned(device1)
            return first.value(), math_ops.add(second, 2.0)

        self.assertAllEqual(func(), (4.0, 5.0))
Beispiel #4
0
  def testNoGarbage(self):
    """Probes a missing attribute on a per-device view of a packed variable.

    NOTE(review): the test name suggests it runs under a garbage /
    reference-cycle check (e.g. a decorator above this `def`, outside this
    view) -- confirm. The statements below are kept exactly as written because
    any restructuring could change which objects are created.
    """
    device0 = device_util.canonicalize('/cpu:0')
    device1 = device_util.canonicalize('/cpu:1')

    # One component variable per CPU device.
    with ops.device(device0):
      v0 = resource_variable_ops.ResourceVariable(1.0)
    with ops.device(device1):
      v1 = resource_variable_ops.ResourceVariable(2.0)

    packed_var = packed_distributed_variable.PackedDistributedVariable([v0, v1])
    # This needs a workaround to avoid creating reference cycles if the
    # attribute doesn't exist.
    # Calling `hasattr` with a name that does not exist exercises the failed
    # attribute-lookup path of the object returned by `on_device`; the boolean
    # result is intentionally discarded. (Presumably the workaround lives in
    # that object's attribute-lookup code -- verify.)
    hasattr(packed_var.on_device('/cpu:0'), 'nonexist')