def worker_fn():
    """Run a read / assign / read sequence on `v` under a read-caching scope.

    `v` and `self` are not defined in this function; they are resolved from
    an enclosing scope that is not visible here.

    Returns:
      Whatever `self.strategy.run` returns for `replica_fn`. The surrounding
      comments expect the second read to report 1.0 (the value seen by the
      first read) rather than the freshly assigned 5.0.
    """
    with distribute_utils.cache_variable_reads():

        def replica_fn():
            # First read inside the caching scope; its result is not used
            # directly, but the read itself must still happen before the
            # assignment below.
            v.read_value()
            # Overwrite the variable with 5.0.
            v.assign(constant_op.constant(5.0))
            # Second read: expected to yield 1.0 instead of 5.0.
            value_after_assign = v.read_value()
            return value_after_assign

        return self.strategy.run(replica_fn)
def testVariableCaching(self):
    """Exercise `distribute_utils.cache_variable_reads` on a strategy variable.

    The test asserts, in order, that:
      * explicit `read_value()` calls inside a caching scope keep reporting
        the value seen by the first read, even after an `assign`;
      * the same holds when the variable is used implicitly in arithmetic;
      * the behaviour holds when the scope is opened inside a
        `def_function.function` that is scheduled through the coordinator;
      * reads outside any caching scope observe the latest assigned value;
      * the behaviour holds when the scope is opened around the `schedule`
        call instead of inside the function;
      * nesting two caching scopes raises `ValueError`.
    """
    self.assertFalse(distribution_strategy_context.in_cross_replica_context())
    with self.strategy.scope():
        self.assertTrue(
            distribution_strategy_context.in_cross_replica_context())
        v = variables.Variable(
            initial_value=1.,
            aggregation=variable_scope.VariableAggregation.ONLY_FIRST_REPLICA)

        # --- Explicit reads inside a caching scope ---
        with distribute_utils.cache_variable_reads():
            v.read_value()  # First read observes 1.0.
            v.assign(constant_op.constant(5.0))  # Variable now holds 5.0.
            # The read is expected to report the cached 1.0.
            self.assertEqual(v.read_value(), 1.0)

        # Reset the variable to 2.0 before the next scenario.
        v.assign(2.0)

        # --- Implicit conversion to a tensor inside a caching scope ---
        with distribute_utils.cache_variable_reads():
            product_before_assign = v * 3.0
            self.assertEqual(product_before_assign, 6.0)
            v.assign(3.0)
            product_after_assign = v * 3.0
            # Still expected to be the cached 2.0 * 3.0.
            self.assertEqual(product_after_assign, 6.0)

        # Reset the variable to 1.0 before the next scenario.
        v.assign(1.0)

        # --- Caching scope opened inside a tf.function ---
        @def_function.function
        def worker_fn():
            with distribute_utils.cache_variable_reads():

                def replica_fn():
                    v.read_value()  # First read observes 1.0.
                    v.assign(constant_op.constant(5.0))  # Now 5.0.
                    # Expected to be 1.0 rather than 5.0.
                    return v.read_value()

                return self.strategy.run(replica_fn)

        remote_value = self.coordinator.schedule(worker_fn)
        self.assertEqual(remote_value.fetch(), 1.)

        # --- Reads outside of any caching scope see the latest value ---
        v.assign(4.0)
        self.assertEqual(v.read_value(), 4.0)

        v.assign(constant_op.constant(2.0))  # Variable now holds 2.0.

        # --- Caching scope opened outside the tf.function; also checks that
        # the cache from the earlier scopes has been reset ---
        @def_function.function
        def worker_fn1():

            def replica_fn():
                v.read_value()  # Reads 2.0, which should become cached.
                v.assign(constant_op.constant(5.0))  # Now 5.0.
                # Expected to be the cached 2.0 rather than 5.0.
                return v.read_value()

            return self.strategy.run(replica_fn)

        with distribute_utils.cache_variable_reads():
            remote_value = self.coordinator.schedule(worker_fn1)
        self.assertEqual(remote_value.fetch(), 2.)

    # --- Nested caching scopes are rejected ---
    with self.assertRaises(ValueError):
        with distribute_utils.cache_variable_reads():
            with distribute_utils.cache_variable_reads():
                v.read_value()