예제 #1
0
    def testWithLayers(self, distribution):
        """Check that variables of layers built per replica are mirrored.

        Three `core.Dense(1)` layers are built inside the replica function,
        the last one after a `merge_call`. Every returned kernel and bias is
        asserted to be mirrored, and its local results are passed to
        `assertAllDifferent`.

        Args:
            distribution: The strategy under test; it provides the input
                iterator, the scope and `call_for_each_replica`.
        """

        def model_fn(features):
            """Build three dense layers and return their (kernel, bias) pairs."""
            layer1 = core.Dense(1)
            layer1(features)
            layer2 = core.Dense(1)
            layer2(features)
            # We rely on names and orders to make sure replica references the same
            # MirroredVariable. Uniquifying names may involve global states,
            # merge_call switches threads so we need to test things work after
            # merge_call.
            ds_context.get_replica_context().merge_call(lambda _: _)
            layer3 = core.Dense(1)
            layer3(features)
            return [(layer1.kernel, layer1.bias), (layer2.kernel, layer2.bias),
                    (layer3.kernel, layer3.bias)]

        iterator = distribution.make_input_fn_iterator(
            lambda _: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10))
        self.evaluate(iterator.initializer)
        features = iterator.get_next()

        with distribution.scope():
            result = distribution.extended.call_for_each_replica(
                model_fn, args=(features, ))
            for kernel, bias in result:
                self.assertTrue(distribute_utils.is_mirrored(kernel))
                self.assertAllDifferent(
                    distribution.experimental_local_results(kernel))
                self.assertTrue(distribute_utils.is_mirrored(bias))
                # Check the bias components here; the kernel was already
                # covered by the assertion above.
                self.assertAllDifferent(
                    distribution.experimental_local_results(bias))
예제 #2
0
    def _reduce_to(self, reduce_op, value, destinations, options):
        """Reduce `value` to `destinations` using `reduce_op`.

        A mirrored `value` combined with `ReduceOp.MEAN` is returned as is.
        Any other mirrored `value` trips the assertion. Otherwise the
        reduction is applied to `value` through `nest.map_structure`.
        """
        if distribute_utils.is_mirrored(value):
            if reduce_op == reduce_util.ReduceOp.MEAN:
                return value
        assert not distribute_utils.is_mirrored(value)

        def reduce_leaf(leaf):
            """Reduce a single element of the `value` structure."""
            if isinstance(leaf, values.DistributedValues):
                if self._use_merge_call() and self._collective_ops_in_use:
                    # pylint: disable=protected-access
                    needs_one_device_reduction = (
                        not cross_device_ops_lib._devices_match(
                            leaf, destinations)
                        or any(
                            "cpu" in d.lower() for d in
                            cross_device_ops_lib.get_devices_from(
                                destinations)))
                    if needs_one_device_reduction:
                        one_device_ops = (
                            cross_device_ops_lib.ReductionToOneDevice())
                        return one_device_ops.reduce(
                            reduce_op, leaf, destinations)
                cross_device_ops = self._get_cross_device_ops(leaf)
                merged_options = self._communication_options.merge(options)
                return cross_device_ops.reduce(
                    reduce_op,
                    leaf,
                    destinations=destinations,
                    options=merged_options)
            # Anything that is not a DistributedValues (neither PerReplica nor
            # Mirrored) lands here. For example, the same value could be
            # present on all replicas, in which case `leaf` would be a single
            # value, or it could be 0.
            return cross_device_ops_lib.reduce_non_distributed_value(
                reduce_op, leaf, destinations, self._num_replicas_in_sync)

        return nest.map_structure(reduce_leaf, value)
예제 #3
0
 def read_var(self, replica_local_var):
   """Return the aggregate value of a replica-local variable.

   A sync-on-read variable yields its `_get_cross_replica()` result. Any other
   variable must be mirrored, and an identity of its `_get()` value is
   returned.
   """
   # pylint: disable=protected-access
   if not distribute_utils.is_sync_on_read(replica_local_var):
     assert distribute_utils.is_mirrored(replica_local_var)
     local_value = replica_local_var._get()
     return array_ops.identity(local_value)
   return replica_local_var._get_cross_replica()
예제 #4
0
 def _reduce_to(self, reduce_op, value, destinations, options):
   """Reduce `value` to `destinations` using `reduce_op`.

   A mirrored `value` combined with `ReduceOp.MEAN` is returned as is. Any
   other mirrored `value` trips the assertion.
   """
   if distribute_utils.is_mirrored(value):
     if reduce_op == reduce_util.ReduceOp.MEAN:
       return value
   assert not distribute_utils.is_mirrored(value)
   if isinstance(value, values.DistributedValues):
     cross_device_ops = self._get_cross_device_ops(value)
     merged_options = self._communication_options.merge(options)
     return cross_device_ops.reduce(
         reduce_op, value, destinations=destinations, options=merged_options)
   # Anything that is not a DistributedValues (neither PerReplica nor Mirrored)
   # lands here. For example, the same value could be present on all replicas,
   # in which case `value` would be a single value, or it could be 0.
   return cross_device_ops_lib.reduce_non_distributed_value(
       reduce_op, value, destinations, self._num_replicas_in_sync)
 def _test_mv_properties(self, var, name, strategy):
   """Assert that `var` is a mirrored variable with the expected properties.

   Args:
     var: The variable under test.
     name: Expected value of `var.name`.
     strategy: Strategy expected on `var.distribute_strategy` and on every
       local result of `var`.
   """
   # pylint: disable=protected-access
   self.assertTrue(distribute_utils.is_mirrored(var))
   self.assertEqual(name, var.name)
   self.assertIs(strategy, var.distribute_strategy)
   # The local results do not depend on the loop index, so fetch them once
   # instead of twice per iteration.
   local_results = strategy.experimental_local_results(var)
   for i, d in enumerate(var._devices):
     self.assertEqual(d, local_results[i].device)
     self.assertIs(strategy, local_results[i]._distribute_strategy)
예제 #6
0
    def testWithVariableAndVariableScope(self, distribution):
        """Check types, names and aggregation of variables created per replica.

        Four variables are created in the replica function, three of them
        inside the "common" variable scope. The test asserts which of them are
        mirrored or sync-on-read, their names, and the aggregation of the two
        that set one explicitly.
        """

        def model_fn():
            """Create var0 at top level and var1..var3 under "common"."""
            sync_modes = variable_scope.VariableSynchronization
            agg_modes = variable_scope.VariableAggregation

            top_level = variable_scope.variable(
                1.0, name="var0", aggregation=None)
            with variable_scope.variable_scope("common"):
                scoped = variable_scope.variable(1.0, name="var1")
                # merge_call pauses the current thread and lets the other
                # thread execute.
                ds_context.get_replica_context().merge_call(lambda _: _)
                on_read = variable_scope.variable(
                    1.0,
                    name="var2",
                    synchronization=sync_modes.ON_READ,
                    aggregation=agg_modes.SUM)
                on_write = variable_scope.variable(
                    1.0,
                    name="var3",
                    synchronization=sync_modes.ON_WRITE,
                    aggregation=agg_modes.MEAN)

            return top_level, scoped, on_read, on_write

        with distribution.scope():
            main_var = variable_scope.variable(1.0, name="var-main0")
            self.assertEqual("var-main0:0", main_var.name)

            replica_vars = distribution.extended.call_for_each_replica(
                model_fn)
            self.assertEqual(4, len(replica_vars))
            top_level, scoped, on_read, on_write = replica_vars

            # var0: created outside the "common" scope.
            self.assertTrue(distribute_utils.is_mirrored(top_level))
            self.assertEqual("var0:0", top_level.name)

            # var1: created inside "common" before the merge_call.
            self.assertTrue(distribute_utils.is_mirrored(scoped))
            self.assertEqual("common/var1:0", scoped.name)

            # var2: ON_READ synchronization with SUM aggregation.
            self.assertTrue(distribute_utils.is_sync_on_read(on_read))
            self.assertEqual("common/var2:0", on_read.name)
            self.assertEqual(
                variable_scope.VariableAggregation.SUM, on_read.aggregation)

            # var3: ON_WRITE synchronization with MEAN aggregation.
            self.assertTrue(distribute_utils.is_mirrored(on_write))
            self.assertEqual("common/var3:0", on_write.name)
            self.assertEqual(
                variable_scope.VariableAggregation.MEAN, on_write.aggregation)
예제 #7
0
    def testVariableWithSameCanonicalNameAcrossThreads(self, distribution):
        """Check the name of a variable whose per-replica names differ.

        Each replica names its variable "foo_<replica id>". The result of
        `call_for_each_replica` must be mirrored and named "foo_0:0".
        """

        def model_fn():
            """Create one variable named after the evaluated replica id."""
            current_replica = self.evaluate(_replica_id())
            var_name = "foo_" + str(current_replica)
            created = variable_scope.variable(1.0, name=var_name)
            ds_context.get_replica_context().merge_call(lambda _: _)
            return created

        with distribution.scope():
            mirrored_var = distribution.extended.call_for_each_replica(
                model_fn)
            self.assertTrue(distribute_utils.is_mirrored(mirrored_var))
            # The mirrored variable takes its name from the first device.
            self.assertEqual("foo_0:0", mirrored_var.name)
예제 #8
0
    def testMultipleVariablesWithSameCanonicalName(self, distribution):
        """Check that four similarly named variables keep their own names.

        The replica function creates variables named "foo/bar", "foo_1/bar",
        "foo_1/bar_1" and "foo/bar_1". Each result must be mirrored and must
        carry the matching ":0"-suffixed name, in creation order.
        """

        def model_fn():
            """Create the four variables in order, then call merge_call."""
            requested_names = ("foo/bar", "foo_1/bar", "foo_1/bar_1",
                               "foo/bar_1")
            created = [
                variable_scope.variable(1.0, name=requested)
                for requested in requested_names
            ]
            ds_context.get_replica_context().merge_call(lambda _: _)
            return created

        expected_names = ("foo/bar:0", "foo_1/bar:0", "foo_1/bar_1:0",
                          "foo/bar_1:0")

        with distribution.scope():
            replica_vars = distribution.extended.call_for_each_replica(
                model_fn)
            for candidate in replica_vars:
                self.assertTrue(distribute_utils.is_mirrored(candidate))
            self.assertEqual(4, len(replica_vars))
            for position, expected in enumerate(expected_names):
                self.assertEqual(expected, replica_vars[position].name)