예제 #1
0
    def _update(self, var, fn, args, kwargs, group):
        """Applies `fn` to `var`, once per device when outside a TPU context.

        Inside an enclosing TPU context, `fn` is called a single time on `var`
        itself. Otherwise the update falls back to MirroredStrategy behavior:
        - If `var` has a packed variable, `fn` is applied to that packed
          handle once for each device it lists.
        - Otherwise `fn` is applied to each entry of `var.values` on that
          entry's own device.

        Args:
          var: A `tpu_values.TPUVariableMixin` or a
            `resource_variable_ops.BaseResourceVariable` (asserted).
          fn: Callable invoked as `fn(value, *args, **kwargs)`.
          args: Positional arguments. For the i-th device they are passed
            through `distribute_utils.select_replica_mirrored(i, args)`.
          kwargs: Keyword arguments, narrowed the same way as `args`.
          group: Inside a TPU context, a truthy value returns `fn`'s result
            as-is and a falsy value wraps it in a 1-tuple. Otherwise it is
            forwarded to `distribute_utils.update_regroup`.

        Returns:
          The TPU-context result described above, or the value of
          `distribute_utils.update_regroup(self, updates, group)`.
        """
        assert isinstance(var, (tpu_values.TPUVariableMixin,
                                resource_variable_ops.BaseResourceVariable))

        if tpu_values.enclosing_tpu_context() is not None:
            outcome = fn(var, *args, **kwargs)
            return outcome if group else (outcome, )

        # No TPU context: revert to MirroredStrategy behavior and update the
        # variable on each replica's device directly.
        packed_var = var._packed_variable  # pylint: disable=protected-access
        if packed_var is not None:
            targets = [(packed_var, dev) for dev in packed_var.devices]
        else:
            targets = [(component, component.device) for component in var.values]

        per_device = []
        for idx, (target, dev) in enumerate(targets):
            with ops.device(dev), \
                 distribute_lib.UpdateContext(idx), \
                 ops.name_scope("update_%d" % idx):
                # Non-mirrored args/kwargs are returned as-is by the selector.
                replica_args = distribute_utils.select_replica_mirrored(idx, args)
                replica_kwargs = distribute_utils.select_replica_mirrored(
                    idx, kwargs)
                per_device.append(fn(target, *replica_args, **replica_kwargs))
        return distribute_utils.update_regroup(self, per_device, group)
    def testWrapClass(self):
        """Checks regroup() with `values.Mirrored` and its inverse selectors.

        Regrouping two nested structures must wrap every leaf as
        `values.Mirrored` while preserving the tuple/list/dict nesting.
        `select_replica` and `select_replica_mirrored` must both recover the
        original per-replica structures.
        """
        # Normally a mirrored value would be the same across devices, but
        # for a test it is convenient to be able to tell the values apart.
        result = distribute_utils.regroup(
            (_nested_value("1"), _nested_value("2")), values.Mirrored)
        self.assertIsInstance(result, tuple)
        self.assertLen(result, 3)
        self._is_per_replica(result[0], ["a1", "a2"], values.Mirrored)
        self._is_per_replica(result[2], ["h1", "h2"], values.Mirrored)

        self.assertIsInstance(result[1], list)
        self.assertLen(result[1], 3)
        self._is_per_replica(result[1][0], ["b1", "b2"], values.Mirrored)
        self._is_per_replica(result[1][2], ["g1", "g2"], values.Mirrored)

        self.assertIsInstance(result[1][1], dict)
        self.assertEqual({"c", "e"}, set(result[1][1]))
        self._is_per_replica(result[1][1]["c"], ["d1", "d2"], values.Mirrored)
        self._is_per_replica(result[1][1]["e"], ["f1", "f2"], values.Mirrored)

        # Also test that we can undo the merge using select_replica()
        self.assertEqual(_nested_value("1"),
                         distribute_utils.select_replica(0, result))
        self.assertEqual(_nested_value("2"),
                         distribute_utils.select_replica(1, result))
        # Values are marked as mirrored, so select_replica_mirrored() is allowed.
        self.assertEqual(_nested_value("1"),
                         distribute_utils.select_replica_mirrored(0, result))
        self.assertEqual(_nested_value("2"),
                         distribute_utils.select_replica_mirrored(1, result))
    def testNested(self):
        """Checks regroup() on nested structures without a wrapper class.

        Regrouping two nested structures must produce per-replica leaves while
        preserving the tuple/list/dict nesting. `select_replica` must undo the
        merge. Because the leaves are not mirrored, `select_replica_mirrored`
        must raise TypeError.
        """
        result = distribute_utils.regroup(
            (_nested_value("1"), _nested_value("2")))
        self.assertIsInstance(result, tuple)
        self.assertLen(result, 3)
        self._is_per_replica(result[0], ["a1", "a2"])
        self._is_per_replica(result[2], ["h1", "h2"])

        self.assertIsInstance(result[1], list)
        self.assertLen(result[1], 3)
        self._is_per_replica(result[1][0], ["b1", "b2"])
        self._is_per_replica(result[1][2], ["g1", "g2"])

        self.assertIsInstance(result[1][1], dict)
        self.assertEqual({"c", "e"}, set(result[1][1]))
        self._is_per_replica(result[1][1]["c"], ["d1", "d2"])
        self._is_per_replica(result[1][1]["e"], ["f1", "f2"])

        # Also test that we can undo the merge using select_replica()
        self.assertEqual(_nested_value("1"),
                         distribute_utils.select_replica(0, result))
        self.assertEqual(_nested_value("2"),
                         distribute_utils.select_replica(1, result))
        # select_replica_mirrored() should fail due to non-mirrored values
        with self.assertRaises(TypeError):
            distribute_utils.select_replica_mirrored(0, result)
        with self.assertRaises(TypeError):
            distribute_utils.select_replica_mirrored(1, result)
예제 #4
0
 def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
   """Calls `fn` once for each entry of `colocate_with` and regroups results.

   Args:
     colocate_with: A tuple (asserted); each entry is used as the device for
       `ops.device(...)` during one call of `fn`.
     fn: Callable invoked as `fn(*args, **kwargs)`, with args/kwargs narrowed
       to the i-th replica by `distribute_utils.select_replica_mirrored`.
     args: Positional arguments for `fn`.
     kwargs: Keyword arguments for `fn`.
     group: Forwarded to `distribute_utils.update_regroup`.

   Returns:
     The value of `distribute_utils.update_regroup(self, updates, group)`.
   """
   assert isinstance(colocate_with, tuple)
   # TODO(josh11b): In eager mode, use one thread per device.
   results = []
   for idx, dev in enumerate(colocate_with):
     with ops.device(dev), \
          distribute_lib.UpdateContext(idx), \
          ops.name_scope("update_%d" % idx):
       replica_args = distribute_utils.select_replica_mirrored(idx, args)
       replica_kwargs = distribute_utils.select_replica_mirrored(idx, kwargs)
       results.append(fn(*replica_args, **replica_kwargs))
   return distribute_utils.update_regroup(self, results, group)
예제 #5
0
 def _update(self, var, fn, args, kwargs, group):
   """Calls `fn` on each component of `var`, on that component's device.

   Args:
     var: A `values.DistributedVariable` (asserted).
     fn: Callable invoked as `fn(component, *args, **kwargs)`.
     args: Positional arguments. For the i-th component they are passed
       through `distribute_utils.select_replica_mirrored(i, args)`.
     kwargs: Keyword arguments, narrowed the same way as `args`.
     group: Forwarded to `distribute_utils.update_regroup`.

   Returns:
     The value of `distribute_utils.update_regroup` over the per-component
     results.
   """
   # TODO(josh11b): In eager mode, use one thread per device.
   assert isinstance(var, values.DistributedVariable)
   per_component = []
   for idx, component in enumerate(var.values):
     with ops.device(component.device), \
          distribute_lib.UpdateContext(idx), \
          ops.name_scope("update_%d" % idx):
       # Non-mirrored args/kwargs are returned as-is by the selector.
       replica_args = distribute_utils.select_replica_mirrored(idx, args)
       replica_kwargs = distribute_utils.select_replica_mirrored(idx, kwargs)
       per_component.append(fn(component, *replica_args, **replica_kwargs))
   return distribute_utils.update_regroup(self, per_component, group)
예제 #6
0
    def _update(self, var, fn, args, kwargs, group):
        """Applies `fn` to `var`, per component when outside a TPU context.

        Inside an enclosing TPU context, `fn` is called a single time on `var`
        itself. Otherwise each entry of `var.values` is updated on its own
        device, as MirroredStrategy does.

        Args:
          var: A `tpu_values.TPUVariableMixin` or a
            `resource_variable_ops.BaseResourceVariable` (asserted).
          fn: Callable invoked as `fn(value, *args, **kwargs)`.
          args: Positional arguments. Outside a TPU context they are passed
            through `distribute_utils.select_replica_mirrored(i, args)`.
          kwargs: Keyword arguments, narrowed the same way as `args`.
          group: Inside a TPU context, a truthy value returns `fn`'s result
            as-is and a falsy value wraps it in a 1-tuple. Otherwise it is
            forwarded to `distribute_utils.update_regroup`.

        Returns:
          The TPU-context result described above, or the value of
          `distribute_utils.update_regroup` over the per-component results.
        """
        assert isinstance(var, (tpu_values.TPUVariableMixin,
                                resource_variable_ops.BaseResourceVariable))
        if tpu_values.enclosing_tpu_context() is not None:
            outcome = fn(var, *args, **kwargs)
            return outcome if group else (outcome, )

        # No TPU context: fall back to MirroredStrategy behavior and update
        # every component variable directly.
        per_component = []
        for idx, component in enumerate(var.values):
            with ops.device(component.device), \
                 distribute_lib.UpdateContext(idx), \
                 ops.name_scope("update_%d" % idx):
                # Non-mirrored args/kwargs are returned as-is by the selector.
                replica_args = distribute_utils.select_replica_mirrored(idx, args)
                replica_kwargs = distribute_utils.select_replica_mirrored(
                    idx, kwargs)
                per_component.append(
                    fn(component, *replica_args, **replica_kwargs))
        return distribute_utils.update_regroup(self, per_component, group)