def test_invalid_wrapped_variable(self, distribute):
  with get_distribute_scope(distribute):
    # Wrap a non-variable
    with self.assertRaisesRegexp(ValueError, 'variable must be of type'):
      x = constant_op.constant([1.], dtype=dtypes.float32)
      autocast_variable.create_autocast_variable(x)

    # Wrap a non-floating point variable
    with self.assertRaisesRegexp(ValueError,
                                 'variable must be a floating point'):
      x = get_var(1, dtypes.int32)
      autocast_variable.create_autocast_variable(x)
def test_operator_overloads(self, distribute):
  with get_distribute_scope(distribute):
    for read_dtype in (dtypes.float32, dtypes.float16):
      x = get_var(7., dtypes.float32)
      x = autocast_variable.create_autocast_variable(x)
      with ops.get_default_graph()._enable_auto_casting_variables(
          read_dtype):
        self.evaluate(x.initializer)
        self.assertAlmostEqual(8, self.evaluate(x + 1))
        self.assertAlmostEqual(10, self.evaluate(3 + x))
        self.assertAlmostEqual(14, self.evaluate(x + x))
        self.assertAlmostEqual(5, self.evaluate(x - 2))
        self.assertAlmostEqual(6, self.evaluate(13 - x))
        self.assertAlmostEqual(0, self.evaluate(x - x))
        self.assertAlmostEqual(14, self.evaluate(x * 2))
        self.assertAlmostEqual(21, self.evaluate(3 * x))
        self.assertAlmostEqual(49, self.evaluate(x * x))
        self.assertAlmostEqual(3.5, self.evaluate(x / 2))
        self.assertAlmostEqual(1.5, self.evaluate(10.5 / x))
        self.assertAlmostEqual(3, self.evaluate(x // 2))
        self.assertAlmostEqual(2, self.evaluate(15 // x))
        if read_dtype == dtypes.float32:
          # The "mod" operator does not support float16
          self.assertAlmostEqual(1, self.evaluate(x % 2))
          self.assertAlmostEqual(2, self.evaluate(16 % x))
        self.assertTrue(self.evaluate(x < 12))
        self.assertTrue(self.evaluate(x <= 12))
        self.assertFalse(self.evaluate(x > 12))
        self.assertFalse(self.evaluate(x >= 12))
        self.assertFalse(self.evaluate(12 < x))
        self.assertFalse(self.evaluate(12 <= x))
        self.assertTrue(self.evaluate(12 > x))
        self.assertTrue(self.evaluate(12 >= x))
        self.assertAlmostEqual(343, self.evaluate(pow(x, 3)), places=4)
        self.assertAlmostEqual(128, self.evaluate(pow(2, x)), places=4)
        self.assertAlmostEqual(-7, self.evaluate(-x))
        self.assertAlmostEqual(7, self.evaluate(abs(x)))

        x = get_var([7, 8, 9], dtypes.float32)
        x = autocast_variable.create_autocast_variable(x)
        self.evaluate(x.initializer)
        self.assertEqual(self.evaluate(x[1]), 8)
        if tf2.enabled() and context.executing_eagerly():
          self.assertAllEqual(x == [7., 8., 10.], [True, True, False])
          self.assertAllEqual(x != [7., 8., 10.], [False, False, True])
def test_assign_stays_in_true_dtype(self, distribute):
  with get_distribute_scope(distribute):
    x = get_var(1., dtypes.float32)
    x = autocast_variable.create_autocast_variable(x)
    self.evaluate(x.initializer)
    # small_val is a value such that 1.0 + small_val == 1.0 in fp16, but not
    # in fp32
    small_val = np.finfo('float16').eps / 2
    small_tensor = constant_op.constant(small_val, dtype=dtypes.float32)
    with ops.get_default_graph()._enable_auto_casting_variables(
        dtypes.float16):
      # Variable should be increased, despite it appearing to be the same
      # float16 value.
      self.assertEqual(1. + small_val,
                       self.evaluate(x.assign(1. + small_tensor)))
      self.assertEqual(1., self.evaluate(x.value()))
    self.assertEqual(1. + small_val, self.evaluate(x.value()))

    self.evaluate(x.assign(1.))
    with ops.get_default_graph()._enable_auto_casting_variables(
        dtypes.float16):
      self.assertEqual(1. + small_val,
                       self.evaluate(x.assign_add(small_tensor)))
      self.assertEqual(1., self.evaluate(x.value()))
    self.assertEqual(1. + small_val, self.evaluate(x.value()))
def test_repr(self):
  # We do not test with DistributionStrategy because we do not want to rely
  # on the exact __repr__ output of a DistributedVariable.
  x = get_var(1., dtypes.float32, name='x')
  x = autocast_variable.create_autocast_variable(x)
  if context.executing_eagerly():
    self.assertStartsWith(
        repr(x),
        "<AutoCastVariable 'x:0' shape=() dtype=float32 true_dtype=float32, "
        "numpy=")
    with ops.get_default_graph()._enable_auto_casting_variables(
        dtypes.float16):
      self.assertStartsWith(
          repr(x),
          "<AutoCastVariable 'x:0' shape=() dtype=float16 "
          "true_dtype=float32, numpy=")
  else:
    self.assertEqual(
        repr(x),
        "<AutoCastVariable 'x:0' shape=() dtype=float32 true_dtype=float32>"
    )
    with ops.get_default_graph()._enable_auto_casting_variables(
        dtypes.float16):
      self.assertEqual(
          repr(x),
          "<AutoCastVariable 'x:0' shape=() dtype=float16 true_dtype=float32>"
      )
def test_multiple_source_types(self, loss_scale, strategy_fn,
                               use_tf_function):
  loss_scale = loss_scale(32)
  strategy = strategy_fn()
  with strategy.scope():
    x1 = variables.Variable(1.0)  # Distributed variable
    x2 = variables.Variable([1.0, 2.0])  # Distributed non-scalar variable
    # Distributed AutoCastVariable
    x3 = autocast_variable.create_autocast_variable(variables.Variable(2.0))
  x4 = variables.Variable(2.0)  # Non-distributed variable
  x5 = constant_op.constant(2.0)  # Tensor

  def run_fn():
    with lsgt.LossScaleGradientTape(loss_scale) as g:
      g.watch(x5)
      y = x1 * x2 * x3 * x4 * x5
    return g.gradient(y, [x1, x2, x3, x4, x5])

  x1g, x2g, x3g, x4g, x5g = self._run_with_strategy(run_fn, strategy,
                                                    use_tf_function)
  self.assertEqual(loss_scale(), 32)
  for dy_dx1 in x1g:
    self.assertEqual(dy_dx1, 24.0)
  for dy_dx2 in x2g:
    self.assertAllEqual(dy_dx2, [8.0, 8.0])
  for dy_dx3 in x3g:
    self.assertEqual(dy_dx3, 12.0)
  for dy_dx4 in x4g:
    self.assertEqual(dy_dx4, 12.0)
  for dy_dx5 in x5g:
    self.assertEqual(dy_dx5, 12.0)
def test_read(self, distribute):
  with get_distribute_scope(distribute):
    x = get_var(1., dtypes.float32)
    x = autocast_variable.create_autocast_variable(x)
    self.evaluate(x.initializer)

    # outside of auto cast scope.
    self.assertEqual(x.dtype, dtypes.float32)
    self.assertEqual(x.value().dtype, dtypes.float32)
    self.assertEqual(x.read_value().dtype, dtypes.float32)
    self.assertEqual(array_ops.identity(x).dtype, dtypes.float32)

    # within auto cast scope of different dtype
    with ops.get_default_graph()._enable_auto_casting_variables(
        dtypes.float16):
      self.assertEqual(x.dtype, dtypes.float16)
      self.assertEqual(x.value().dtype, dtypes.float16)
      self.assertEqual(x.read_value().dtype, dtypes.float16)
      self.assertEqual(array_ops.identity(x).dtype, dtypes.float16)

    # within auto cast scope of same dtype
    with ops.get_default_graph()._enable_auto_casting_variables(
        dtypes.float32):
      self.assertEqual(x.dtype, dtypes.float32)
      self.assertEqual(x.value().dtype, dtypes.float32)
      self.assertEqual(x.read_value().dtype, dtypes.float32)
      self.assertEqual(array_ops.identity(x).dtype, dtypes.float32)
def test_repr_distributed(self):
  with mirrored_strategy.MirroredStrategy(['/cpu:1', '/cpu:2']).scope():
    x = get_var(1., dtypes.float32)
    x = autocast_variable.create_autocast_variable(x)
    self.assertRegexpMatches(
        repr(x).replace('\n', ' '),
        '<AutoCastDistributedVariable dtype=float32 true_dtype=float32 '
        'inner_variable=MirroredVariable.*>')
def test_repr_distributed(self):
  with get_distribute_scope(distribute=True):
    x = get_var(1., dtypes.float32)
    x = autocast_variable.create_autocast_variable(x)
    self.assertRegexpMatches(
        repr(x).replace('\n', ' '),
        '<AutoCastDistributedVariable dtype=float32 true_dtype=float32 '
        'inner_variable=MirroredVariable.*>')
def test_assign(self, distribute):
  with get_distribute_scope(distribute):
    x = get_var(0., dtypes.float32)
    x = autocast_variable.create_autocast_variable(x)
    self.evaluate(x.initializer)

    # outside of auto cast scope.
    v1 = constant_op.constant(3.14, dtype=dtypes.float32)
    v2 = constant_op.constant(3.14, dtype=dtypes.float16)

    def run_and_check():
      # Assign float32 values
      self.assertAllClose(3.14, self.evaluate(x.assign(v1)))
      self.assertAllClose(3.14 * 2, self.evaluate(x.assign_add(v1)))
      self.assertAllClose(3.14, self.evaluate(x.assign_sub(v1)))

      # Attempt to assign float16 values
      with self.assertRaisesRegexp(
          ValueError,
          'conversion requested dtype float32 for Tensor with dtype float16'):
        self.evaluate(x.assign(v2))
      with self.assertRaisesRegexp(
          ValueError,
          'conversion requested dtype float32 for Tensor with dtype float16'):
        self.evaluate(x.assign_add(v2))
      with self.assertRaisesRegexp(
          ValueError,
          'conversion requested dtype float32 for Tensor with dtype float16'):
        self.evaluate(x.assign_sub(v2))

      # Assign Python floats
      self.assertAllClose(0., self.evaluate(x.assign(0.)))
      self.assertAllClose(3.14, self.evaluate(x.assign(3.14)))
      self.assertAllClose(3.14 * 2, self.evaluate(x.assign_add(3.14)))
      self.assertAllClose(3.14, self.evaluate(x.assign_sub(3.14)))

      # Use the tf.assign functions instead of the var.assign methods.
      self.assertAllClose(0., self.evaluate(state_ops.assign(x, 0.)))
      self.assertAllClose(3.14, self.evaluate(state_ops.assign(x, 3.14)))
      self.assertAllClose(3.14 * 2,
                          self.evaluate(state_ops.assign_add(x, 3.14)))
      self.assertAllClose(3.14, self.evaluate(state_ops.assign_sub(x, 3.14)))

    run_and_check()
    # reset x
    self.evaluate(x.assign(0.))
    # within auto cast scope.
    with ops.get_default_graph()._enable_auto_casting_variables(
        dtypes.float16):
      # assign still expects a float32 value even in a float16 scope
      run_and_check()
def test_assign_op(self, distribution):
  with distribution.scope():
    x = get_var(0., dtypes.float32)
    x = autocast_variable.create_autocast_variable(x)

    @def_function.function
    def func():
      self.assertIsNotNone(x.assign(1.0).op)
      self.assertIsNotNone(x.assign_add(1.0).op)
      self.assertIsNotNone(x.assign_sub(1.0).op)

    func()
def test_checkpoint(self, distribute):
  with get_distribute_scope(distribute):
    x = get_var(1., dtypes.float32)
    x = autocast_variable.create_autocast_variable(x)
  self.evaluate(x.initializer)
  self.evaluate(x.assign(123.))

  checkpoint = trackable_utils.Checkpoint(x=x)
  prefix = os.path.join(self.get_temp_dir(), 'ckpt')
  save_path = checkpoint.save(prefix)
  self.evaluate(x.assign(234.))
  checkpoint.restore(save_path).assert_consumed().run_restore_ops()
  self.assertEqual(self.evaluate(x), 123.)
def test_sparse_reads(self):
  # DistributedVariables do not support sparse_read or gather_nd, so this
  # test does not use a distribution strategy.
  x = get_var([1., 2], dtypes.float32)
  x = autocast_variable.create_autocast_variable(x)
  self.evaluate(x.initializer)

  self.assertEqual(x.sparse_read([0]).dtype, dtypes.float32)
  self.assertEqual(x.gather_nd([0]).dtype, dtypes.float32)

  with autocast_variable.enable_auto_cast_variables(dtypes.float16):
    self.assertEqual(x.sparse_read([0]).dtype, dtypes.float16)
    self.assertEqual(x.gather_nd([0]).dtype, dtypes.float16)
def test_read_nested_scopes(self, distribution):
  with distribution.scope():
    x = get_var(1., dtypes.float32)
    x = autocast_variable.create_autocast_variable(x)
    self.evaluate(x.initializer)

    with autocast_variable.enable_auto_cast_variables(dtypes.float16):
      self.assertEqual(x.read_value().dtype, dtypes.float16)

      with autocast_variable.enable_auto_cast_variables(dtypes.float32):
        self.assertEqual(x.read_value().dtype, dtypes.float32)

      self.assertEqual(x.read_value().dtype, dtypes.float16)
def test_dtype_is_not_string(self, distribute):
  with get_distribute_scope(distribute):
    x = get_var(1., dtypes.float32)
    x = autocast_variable.create_autocast_variable(x)
    self.assertEqual(x.dtype, dtypes.float32)
    self.assertIsInstance(x.dtype, dtypes.DType)
    self.assertEqual(x.true_dtype, dtypes.float32)
    self.assertIsInstance(x.true_dtype, dtypes.DType)

    with ops.get_default_graph()._enable_auto_casting_variables('float16'):
      self.assertEqual(x.dtype, dtypes.float16)
      self.assertIsInstance(x.dtype, dtypes.DType)
      self.assertEqual(x.true_dtype, dtypes.float32)
      self.assertIsInstance(x.true_dtype, dtypes.DType)
def test_dtype_is_not_string(self, distribution):
  with distribution.scope():
    x = get_var(1., dtypes.float32)
    x = autocast_variable.create_autocast_variable(x)
    self.assertEqual(x.dtype, dtypes.float32)
    self.assertIsInstance(x.dtype, dtypes.DType)
    self.assertEqual(x.true_dtype, dtypes.float32)
    self.assertIsInstance(x.true_dtype, dtypes.DType)

    dtype = dtypes.float16
    with autocast_variable.enable_auto_cast_variables(dtype):
      self.assertEqual(x.dtype, dtypes.float32)
      self.assertIsInstance(x.dtype, dtypes.DType)
      self.assertEqual(x.true_dtype, dtypes.float32)
      self.assertIsInstance(x.true_dtype, dtypes.DType)
def test_assign_tf_function(self, distribution):
  if not context.executing_eagerly():
    self.skipTest('Test is not compatible with graph mode')

  with distribution.scope():
    x = get_var(0., dtypes.float32)
    x = autocast_variable.create_autocast_variable(x)

    @def_function.function
    def run_assign():
      return x.assign(1.).assign_add(3.).assign_add(3.).assign_sub(2.)

    with autocast_variable.enable_auto_cast_variables(dtypes.float16):
      self.assertAllClose(5., self.evaluate(run_assign()))
def test_repr_distributed(self):
  strategy = mirrored_strategy.MirroredStrategy(['/cpu:1', '/cpu:2'])
  with strategy.scope():
    x = get_var(1., dtypes.float32)
    x = autocast_variable.create_autocast_variable(x)
    use_policy = getattr(strategy.extended, '_use_policy', False)
    if use_policy:
      self.assertRegex(
          repr(x).replace('\n', ' '),
          '<AutoCastDistributedVariable dtype=float32 true_dtype=float32 '
          'inner_variable=DistributedVariable.*>')
    else:
      self.assertRegex(
          repr(x).replace('\n', ' '),
          '<AutoCastDistributedVariable dtype=float32 true_dtype=float32 '
          'inner_variable=MirroredVariable.*>')
def test_tf_function_control_dependencies(self, distribution):
  if not context.executing_eagerly():
    self.skipTest('Test is not compatible with graph mode')

  with distribution.scope():
    x = get_var(0., dtypes.float32)
    x = autocast_variable.create_autocast_variable(x)

    @def_function.function
    def func():
      update = x.assign_add(1.)
      with ops.control_dependencies([update]):
        x.assign_add(1.)

    func()
    self.assertAllClose(2., self.evaluate(x))
def test_optimizer(self, optimizer_class):
  x = get_var(1., dtypes.float32)
  x = autocast_variable.create_autocast_variable(x)
  opt = optimizer_class(1.)

  @def_function.function
  def f():
    opt.minimize(lambda: x + 1., var_list=[x])

  if context.executing_eagerly():
    f()
  else:
    op = f()  # pylint: disable=assignment-from-no-return
    self.evaluate(variables.global_variables_initializer())
    self.evaluate(op)
  self.assertEqual(self.evaluate(x), 0)
def test_read_nested_scopes(self, distribute):
  with get_distribute_scope(distribute):
    x = get_var(1., dtypes.float32)
    x = autocast_variable.create_autocast_variable(x)
    self.evaluate(x.initializer)

    with ops.get_default_graph()._enable_auto_casting_variables(
        dtypes.float16):
      self.assertEqual(x.dtype, dtypes.float16)
      self.assertEqual(x.read_value().dtype, dtypes.float16)

      with ops.get_default_graph()._enable_auto_casting_variables(
          dtypes.float32):
        self.assertEqual(x.dtype, dtypes.float32)
        self.assertEqual(x.read_value().dtype, dtypes.float32)

      self.assertEqual(x.dtype, dtypes.float16)
      self.assertEqual(x.read_value().dtype, dtypes.float16)
def test_thread_local_autocast_dtype(self):
  x = get_var(1., dtypes.float32)
  x = autocast_variable.create_autocast_variable(x)
  self.evaluate(x.initializer)

  with autocast_variable.enable_auto_cast_variables(dtypes.float16):
    self.assertEqual(array_ops.identity(x).dtype, dtypes.float16)

    # New threads should not see the modified value of the autocast dtype.
    var_dtype = None

    def f():
      nonlocal var_dtype
      var_dtype = x._cast_dtype

    thread = threading.Thread(target=f)
    thread.start()
    thread.join()
    self.assertEqual(var_dtype, dtypes.float32)
def test_method_delegations(self, distribution):
  # Test AutoCastVariable correctly delegates Variable methods to the
  # underlying variable.
  with self.test_session(), distribution.scope():
    for read_dtype in (dtypes.float32, dtypes.float16):
      if ds_context.has_strategy():
        # MirroredVariable.assign will (incorrectly) return a Mirrored value
        # instead of a MirroredVariable. So we cannot properly wrap it in an
        # AutoCastVariable.
        evaluate = self.evaluate
      else:

        def evaluate(var):
          self.assertIsInstance(var, autocast_variable.AutoCastVariable)
          self.assertEqual(array_ops.identity(var).dtype, read_dtype)  # pylint: disable=cell-var-from-loop
          return self.evaluate(var)

      x = get_var(7., dtypes.float32)
      x = autocast_variable.create_autocast_variable(x)
      with autocast_variable.enable_auto_cast_variables(read_dtype):
        self.evaluate(x.initializer)
        self.assertEqual(self.evaluate(x.value()), 7)
        self.assertEqual(self.evaluate(x.read_value()), 7)
        self.assertTrue(x.trainable)
        self.assertEqual(x.synchronization, x._variable.synchronization)
        self.assertEqual(x.aggregation, x._variable.aggregation)
        self.assertEqual(self.evaluate(x.initialized_value()), 7)
        if not context.executing_eagerly():
          if not ds_context.has_strategy():
            # These functions are not supported for DistributedVariables
            x.load(9)
            self.assertEqual(x.eval(), 9)
          self.assertEqual(self.evaluate(x.initial_value), 7)
          self.assertEqual(x.op, x._variable.op)
          self.assertEqual(x.graph, x._variable.graph)
        if not ds_context.has_strategy():
          # These attributes are not supported for DistributedVariables
          self.assertIsNone(x.constraint)
          self.assertEqual(x.initializer, x._variable.initializer)
        self.assertEqual(evaluate(x.assign(8)), 8)
        self.assertEqual(evaluate(x.assign_add(2)), 10)
        self.assertEqual(evaluate(x.assign_sub(3)), 7)
        self.assertEqual(x.name, x._variable.name)
        self.assertEqual(x.device, x._variable.device)
        self.assertEqual(x.shape, ())
        self.assertEqual(x.get_shape(), ())

      if not ds_context.has_strategy():
        # Test scatter_* methods. These are not supported for
        # DistributedVariables
        x = get_var([7, 8], dtypes.float32)
        x = autocast_variable.create_autocast_variable(x)
        with autocast_variable.enable_auto_cast_variables(read_dtype):
          self.evaluate(x.initializer)
          self.assertAllEqual(self.evaluate(x.value()), [7, 8])

          def slices(val, index):
            return indexed_slices.IndexedSlices(
                values=constant_op.constant(val, dtype=dtypes.float32),
                indices=constant_op.constant(index, dtype=dtypes.int32),
                dense_shape=constant_op.constant([2], dtype=dtypes.int32))

          self.assertAllEqual(evaluate(x.scatter_sub(slices(1., 0))), [6, 8])
          self.assertAllEqual(evaluate(x.scatter_add(slices(1., 0))), [7, 8])
          self.assertAllEqual(evaluate(x.scatter_max(slices(9., 1))), [7, 9])
          self.assertAllEqual(evaluate(x.scatter_min(slices(8., 1))), [7, 8])
          self.assertAllEqual(evaluate(x.scatter_mul(slices(2., 1))), [7, 16])
          self.assertAllEqual(evaluate(x.scatter_div(slices(2., 1))), [7, 8])
          self.assertAllEqual(
              evaluate(x.scatter_update(slices(4., 1))), [7, 4])
          self.assertAllEqual(
              evaluate(x.scatter_nd_sub([[0], [1]], [1., 2.])), [6, 2])
          self.assertAllEqual(
              evaluate(x.scatter_nd_add([[0], [1]], [1., 2.])), [7, 4])
          self.assertAllEqual(
              evaluate(x.scatter_nd_update([[0], [1]], [1., 2.])), [1, 2])
def test_method_delegations(self, distribute):
  # Test AutoCastVariable correctly delegates Variable methods to the
  # underlying variable.
  with get_distribute_scope(distribute):
    evaluate = self.evaluate
    for read_dtype in (dtypes.float32, dtypes.float16):
      x = get_var(7., dtypes.float32)
      x = autocast_variable.create_autocast_variable(x)
      with ops.get_default_graph()._enable_auto_casting_variables(
          read_dtype):
        evaluate(x.initializer)
        self.assertEqual(evaluate(x.value()), 7)
        self.assertEqual(evaluate(x.read_value()), 7)
        self.assertTrue(x.trainable)
        self.assertEqual(x.synchronization, x._variable.synchronization)
        self.assertEqual(x.aggregation, x._variable.aggregation)
        self.assertEqual(evaluate(x.initialized_value()), 7)
        if not context.executing_eagerly():
          if not distribute:
            # These functions are not supported for DistributedVariables
            x.load(9)
            self.assertEqual(x.eval(), 9)
          self.assertEqual(evaluate(x.initial_value), 7)
          self.assertEqual(x.op, x._variable.op)
          self.assertEqual(x.graph, x._variable.graph)
        if not distribute:
          # These attributes are not supported for DistributedVariables
          self.assertIsNone(x.constraint)
          self.assertEqual(x.initializer, x._variable.initializer)
        self.assertEqual(evaluate(x.assign(8)), 8)
        self.assertEqual(evaluate(x.assign_add(2)), 10)
        self.assertEqual(evaluate(x.assign_sub(3)), 7)
        self.assertEqual(x.name, x._variable.name)
        self.assertEqual(x.device, x._variable.device)
        self.assertEqual(x.shape, ())
        self.assertEqual(x.get_shape(), ())

      if not distribute:
        # Test scatter_* methods. These are not supported for
        # DistributedVariables
        x = get_var([7, 8], dtypes.float32)
        x = autocast_variable.create_autocast_variable(x)
        with ops.get_default_graph()._enable_auto_casting_variables(
            read_dtype):
          evaluate(x.initializer)
          self.assertAllEqual(evaluate(x.value()), [7, 8])

          def slices(val, index):
            return indexed_slices.IndexedSlices(
                values=constant_op.constant(val, dtype=dtypes.float32),
                indices=constant_op.constant(index, dtype=dtypes.int32),
                dense_shape=constant_op.constant([2], dtype=dtypes.int32))

          self.assertAllEqual(evaluate(x.scatter_sub(slices(1., 0))), [6, 8])
          self.assertAllEqual(evaluate(x.scatter_add(slices(1., 0))), [7, 8])
          self.assertAllEqual(evaluate(x.scatter_max(slices(9., 1))), [7, 9])
          self.assertAllEqual(evaluate(x.scatter_min(slices(8., 1))), [7, 8])
          self.assertAllEqual(evaluate(x.scatter_mul(slices(2., 1))), [7, 16])
          self.assertAllEqual(evaluate(x.scatter_div(slices(2., 1))), [7, 8])
          self.assertAllEqual(
              evaluate(x.scatter_update(slices(4., 1))), [7, 4])
          self.assertAllEqual(
              evaluate(x.scatter_nd_sub([[0], [1]], [1., 2.])), [6, 2])
          self.assertAllEqual(
              evaluate(x.scatter_nd_add([[0], [1]], [1., 2.])), [7, 4])
          self.assertAllEqual(
              evaluate(x.scatter_nd_update([[0], [1]], [1., 2.])), [1, 2])
def test_assign(self, distribution):
  with distribution.scope():
    x = get_var(0., dtypes.float32)
    x = autocast_variable.create_autocast_variable(x)
    self.evaluate(x.initializer)

    # outside of auto cast scope.
    v1 = constant_op.constant(3., dtype=dtypes.float32)
    v2 = constant_op.constant(3., dtype=dtypes.float16)

    def run_and_check():
      # Assign float32 values
      self.assertAllClose(3., self.evaluate(x.assign(v1)))
      self.assertAllClose(3. * 2, self.evaluate(x.assign_add(v1)))
      self.assertAllClose(3., self.evaluate(x.assign_sub(v1)))

      # Attempt to assign float16 values
      with self.assertRaisesRegex(
          ValueError,
          'conversion requested dtype float32 for Tensor with dtype float16'):
        self.evaluate(x.assign(v2))
      with self.assertRaisesRegex(
          ValueError,
          'conversion requested dtype float32 for Tensor with dtype float16'):
        self.evaluate(x.assign_add(v2))
      with self.assertRaisesRegex(
          ValueError,
          'conversion requested dtype float32 for Tensor with dtype float16'):
        self.evaluate(x.assign_sub(v2))

      # Assign Python floats
      self.assertAllClose(0., self.evaluate(x.assign(0.)))
      self.assertAllClose(3., self.evaluate(x.assign(3.)))
      self.assertAllClose(3. * 2, self.evaluate(x.assign_add(3.)))
      self.assertAllClose(3., self.evaluate(x.assign_sub(3.)))

      # Assign multiple times
      # This currently doesn't work in graph mode if a strategy is used
      if not ds_context.has_strategy() or context.executing_eagerly():
        assign = x.assign(1.)
        self.assertAllClose(1., self.evaluate(assign))
        self.assertAllClose(0., self.evaluate(assign.assign(0.)))
        assign_add = x.assign_add(3.)
        self.assertAllClose(3., self.evaluate(assign_add))
        self.assertAllClose(3. * 3,
                            self.evaluate(x.assign_add(3.).assign_add(3.)))
        self.assertAllClose(3. * 3, x)
        assign_sub = x.assign_sub(3.)
        self.assertAllClose(3. * 2, self.evaluate(assign_sub))
        self.assertAllClose(0.,
                            self.evaluate(x.assign_sub(3.).assign_sub(3.)))

      # Assign with read_value=False
      self.assertIsNone(self.evaluate(x.assign(1., read_value=False)))
      self.assertAllClose(1., self.evaluate(x))
      self.assertIsNone(self.evaluate(x.assign_add(2., read_value=False)))
      self.assertAllClose(3., self.evaluate(x))
      self.assertIsNone(self.evaluate(x.assign_sub(3., read_value=False)))
      self.assertAllClose(0., self.evaluate(x))

      # Use the tf.assign functions instead of the var.assign methods.
      self.assertAllClose(0., self.evaluate(state_ops.assign(x, 0.)))
      self.assertAllClose(3., self.evaluate(state_ops.assign(x, 3.)))
      self.assertAllClose(3. * 2, self.evaluate(state_ops.assign_add(x, 3.)))
      self.assertAllClose(3., self.evaluate(state_ops.assign_sub(x, 3.)))

    run_and_check()
    # reset x
    self.evaluate(x.assign(0.))
    # within auto cast scope.
    with autocast_variable.enable_auto_cast_variables(dtypes.float16):
      # assign still expects a float32 value even in a float16 scope
      run_and_check()