예제 #1
0
 def model_fn():
     """Build two ON_READ variables, update them, and record the results.

     One variable uses SUM aggregation (initial value 1.0) and the other
     uses MEAN aggregation (initial value 4.0). Both the variables and the
     objects returned by their `_get()` calls are stored, keyed by the
     evaluated replica id, in containers defined outside this function.

     Returns:
       A tuple `(updates, v_sum, v_mean, c_sum, c_mean)` where `updates` is
       the list of the two assignment results.
     """
     rid = self.evaluate(_replica_id())
     on_read = variable_scope.VariableSynchronization.ON_READ
     aggregations = variable_scope.VariableAggregation

     sum_var = variable_scope.variable(
         1.0, synchronization=on_read, aggregation=aggregations.SUM)
     mean_var = variable_scope.variable(
         4.0, synchronization=on_read, aggregation=aggregations.MEAN)
     for created in (sum_var, mean_var):
         self.assertTrue(distribute_utils.is_sync_on_read(created))

     # The update amounts depend on the replica id.
     add_op = sum_var.assign_add(2.0 + rid)
     set_op = mean_var.assign(6.0 * rid)

     all_v_sum[rid] = sum_var
     all_v_mean[rid] = mean_var

     sum_component = sum_var._get()
     mean_component = mean_var._get()
     components_sum[rid] = sum_component
     components_mean[rid] = mean_component

     # `_get()` must hand back a different object than the variable itself.
     for whole, part in ((sum_var, sum_component), (mean_var, mean_component)):
         self.assertIsNot(whole, part)

     return [add_op, set_op], sum_var, mean_var, sum_component, mean_component
예제 #2
0
 def read_var(self, replica_local_var):
   """Return the value read from a replica-local variable.

   Args:
     replica_local_var: The variable to read. It must be either a
       sync-on-read variable or a mirrored variable.

   Returns:
     The result of `_get_cross_replica()` for a sync-on-read variable;
     otherwise an identity op wrapping the result of `_get()`.
   """
   # pylint: disable=protected-access
   if not distribute_utils.is_sync_on_read(replica_local_var):
     # Anything that is not sync-on-read is required to be mirrored.
     assert distribute_utils.is_mirrored(replica_local_var)
     local_value = replica_local_var._get()
     return array_ops.identity(local_value)
   return replica_local_var._get_cross_replica()
예제 #3
0
 def model_fn():
     """Create a SUM-aggregated ON_READ variable, verify it, and return it."""
     sync_mode = variable_scope.VariableSynchronization.ON_READ
     agg_mode = variable_scope.VariableAggregation.SUM
     created = variable_scope.variable(
         1.0, synchronization=sync_mode, aggregation=agg_mode)
     is_on_read = distribute_utils.is_sync_on_read(created)
     self.assertTrue(is_on_read)
     return created
예제 #4
0
    def testWithVariableAndVariableScope(self, distribution):
        """Check kinds, names and aggregation of variables made per replica.

        Variables are created both directly and inside a "common" variable
        scope, with a `merge_call` issued between creations, and the values
        returned by `call_for_each_replica` are then inspected.
        """
        sync_modes = variable_scope.VariableSynchronization
        agg_modes = variable_scope.VariableAggregation

        def model_fn():
            """Create var0 outside and var1..var3 inside the "common" scope."""
            plain = variable_scope.variable(
                1.0, name="var0", aggregation=None)
            with variable_scope.variable_scope("common"):
                scoped = variable_scope.variable(1.0, name="var1")
                # This will pause the current thread, and execute the other
                # thread.
                ds_context.get_replica_context().merge_call(lambda _: _)
                on_read_sum = variable_scope.variable(
                    1.0,
                    name="var2",
                    synchronization=sync_modes.ON_READ,
                    aggregation=agg_modes.SUM)
                on_write_mean = variable_scope.variable(
                    1.0,
                    name="var3",
                    synchronization=sync_modes.ON_WRITE,
                    aggregation=agg_modes.MEAN)

            return plain, scoped, on_read_sum, on_write_mean

        with distribution.scope():
            main_var = variable_scope.variable(1.0, name="var-main0")
            self.assertEqual("var-main0:0", main_var.name)

            outputs = distribution.extended.call_for_each_replica(model_fn)
            self.assertEqual(4, len(outputs))
            var0, var1, var2, var3 = outputs

            # var0: created outside the "common" scope.
            self.assertTrue(distribute_utils.is_mirrored(var0))
            self.assertEqual("var0:0", var0.name)

            # var1: created inside the scope with default settings.
            self.assertTrue(distribute_utils.is_mirrored(var1))
            self.assertEqual("common/var1:0", var1.name)

            # var2: ON_READ synchronization with SUM aggregation.
            self.assertTrue(distribute_utils.is_sync_on_read(var2))
            self.assertEqual("common/var2:0", var2.name)
            self.assertEqual(agg_modes.SUM, var2.aggregation)

            # var3: ON_WRITE synchronization with MEAN aggregation.
            self.assertTrue(distribute_utils.is_mirrored(var3))
            self.assertEqual("common/var3:0", var3.name)
            self.assertEqual(agg_modes.MEAN, var3.aggregation)