def model_fn():
  """Create SUM and MEAN sync-on-read variables and issue per-replica updates.

  The distributed variables and their local components are stored in the
  enclosing `all_v_sum`, `all_v_mean`, `components_sum` and `components_mean`
  containers, indexed by this replica's id.

  Returns:
    A tuple `(updates, v_sum, v_mean, c_sum, c_mean)`: the list holding the
    results of the `assign_add`/`assign` calls, the two distributed
    variables, and this replica's local component of each.
  """
  rid = self.evaluate(_replica_id())
  # Creation order is kept (SUM first, then MEAN).
  summed = variable_scope.variable(
      1.0,
      synchronization=variable_scope.VariableSynchronization.ON_READ,
      aggregation=variable_scope.VariableAggregation.SUM)
  averaged = variable_scope.variable(
      4.0,
      synchronization=variable_scope.VariableSynchronization.ON_READ,
      aggregation=variable_scope.VariableAggregation.MEAN)
  for created in (summed, averaged):
    self.assertTrue(distribute_utils.is_sync_on_read(created))

  # Each replica writes a value that depends on its id.
  update_results = [summed.assign_add(2.0 + rid), averaged.assign(6.0 * rid)]

  all_v_sum[rid] = summed
  all_v_mean[rid] = averaged

  # pylint: disable=protected-access
  local_sum = summed._get()
  local_mean = averaged._get()
  # pylint: enable=protected-access
  components_sum[rid] = local_sum
  components_mean[rid] = local_mean

  # The distributed wrapper must be a different object from its component.
  self.assertIsNot(summed, local_sum)
  self.assertIsNot(averaged, local_mean)
  return update_results, summed, averaged, local_sum, local_mean
def read_var(self, replica_local_var):
  """Return the aggregated value of a replica-local variable.

  A sync-on-read variable is read through its cross-replica accessor.
  Otherwise the variable is expected to be mirrored, and an identity of its
  local component is returned.

  Args:
    replica_local_var: a sync-on-read or mirrored distributed variable.

  Returns:
    The cross-replica value (sync-on-read) or an identity of the local
    component (mirrored).
  """
  # pylint: disable=protected-access
  if not distribute_utils.is_sync_on_read(replica_local_var):
    # Only mirrored variables are expected on this path.
    assert distribute_utils.is_mirrored(replica_local_var)
    return array_ops.identity(replica_local_var._get())
  return replica_local_var._get_cross_replica()
def model_fn():
  """Create a SUM-aggregated sync-on-read variable, check it, and return it."""
  sync_on_read_var = variable_scope.variable(
      1.0,
      synchronization=variable_scope.VariableSynchronization.ON_READ,
      aggregation=variable_scope.VariableAggregation.SUM)
  self.assertTrue(distribute_utils.is_sync_on_read(sync_on_read_var))
  return sync_on_read_var
def testWithVariableAndVariableScope(self, distribution):
  """Check kinds, names and aggregations of variables made in a replica fn.

  One variable is created directly in the strategy scope. Four more are
  created per replica: one outside and three inside the "common" variable
  scope, with a `merge_call` between the first and second scoped variables.
  """

  def model_fn():
    plain_var = variable_scope.variable(1.0, name="var0", aggregation=None)
    with variable_scope.variable_scope("common"):
      scoped_var = variable_scope.variable(1.0, name="var1")
      # Calling merge_call pauses the current thread and executes the other
      # thread.
      ds_context.get_replica_context().merge_call(lambda _: _)
      on_read_var = variable_scope.variable(
          1.0,
          name="var2",
          synchronization=variable_scope.VariableSynchronization.ON_READ,
          aggregation=variable_scope.VariableAggregation.SUM)
      on_write_var = variable_scope.variable(
          1.0,
          name="var3",
          synchronization=variable_scope.VariableSynchronization.ON_WRITE,
          aggregation=variable_scope.VariableAggregation.MEAN)
    return plain_var, scoped_var, on_read_var, on_write_var

  with distribution.scope():
    main_var = variable_scope.variable(1.0, name="var-main0")
    self.assertEqual("var-main0:0", main_var.name)

    replica_vars = distribution.extended.call_for_each_replica(model_fn)
    self.assertEqual(4, len(replica_vars))

    # One row per returned variable: (kind predicate, expected name,
    # expected aggregation or None when aggregation is not checked).
    expected_rows = (
        (distribute_utils.is_mirrored, "var0:0", None),
        (distribute_utils.is_mirrored, "common/var1:0", None),
        (distribute_utils.is_sync_on_read, "common/var2:0",
         variable_scope.VariableAggregation.SUM),
        (distribute_utils.is_mirrored, "common/var3:0",
         variable_scope.VariableAggregation.MEAN),
    )
    for var, (is_kind, expected_name, expected_agg) in zip(
        replica_vars, expected_rows):
      self.assertTrue(is_kind(var))
      self.assertEqual(expected_name, var.name)
      if expected_agg is not None:
        self.assertEqual(expected_agg, var.aggregation)