Example #1
  def test_invalid_wrapped_variable(self, distribute):
    with get_distribute_scope(distribute):
      # Wrap a non-variable
      with self.assertRaisesRegexp(ValueError, 'variable must be of type'):
        x = constant_op.constant([1.], dtype=dtypes.float32)
        autocast_variable.create_autocast_variable(x)

      # Wrap a non-floating point variable
      with self.assertRaisesRegexp(ValueError,
                                   'variable must be a floating point'):
        x = get_var(1, dtypes.int32)
        autocast_variable.create_autocast_variable(x)
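
The examples below also rely on two small helpers from the test module, get_var and get_distribute_scope, which are not shown here. A minimal sketch of what they might look like (this particular implementation is an assumption, not the original):

import contextlib

from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.ops import variables


def get_var(initial_value, dtype, name=None):
    # Plain variable with an explicit dtype, matching how the tests call it.
    return variables.Variable(initial_value, dtype=dtype, name=name)


def get_distribute_scope(distribute):
    # Enter a MirroredStrategy scope when requested, otherwise a no-op scope.
    if distribute:
        return mirrored_strategy.MirroredStrategy(['cpu:0']).scope()
    return contextlib.nullcontext()
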
    def test_operator_overloads(self, distribute):
        with get_distribute_scope(distribute):
            for read_dtype in (dtypes.float32, dtypes.float16):
                x = get_var(7., dtypes.float32)
                x = autocast_variable.create_autocast_variable(x)
                with ops.get_default_graph()._enable_auto_casting_variables(
                        read_dtype):
                    self.evaluate(x.initializer)
                    self.assertAlmostEqual(8, self.evaluate(x + 1))
                    self.assertAlmostEqual(10, self.evaluate(3 + x))
                    self.assertAlmostEqual(14, self.evaluate(x + x))
                    self.assertAlmostEqual(5, self.evaluate(x - 2))
                    self.assertAlmostEqual(6, self.evaluate(13 - x))
                    self.assertAlmostEqual(0, self.evaluate(x - x))
                    self.assertAlmostEqual(14, self.evaluate(x * 2))
                    self.assertAlmostEqual(21, self.evaluate(3 * x))
                    self.assertAlmostEqual(49, self.evaluate(x * x))
                    self.assertAlmostEqual(3.5, self.evaluate(x / 2))
                    self.assertAlmostEqual(1.5, self.evaluate(10.5 / x))
                    self.assertAlmostEqual(3, self.evaluate(x // 2))
                    self.assertAlmostEqual(2, self.evaluate(15 // x))
                    if read_dtype == dtypes.float32:
                        # The "mod" operator does not support float16
                        self.assertAlmostEqual(1, self.evaluate(x % 2))
                        self.assertAlmostEqual(2, self.evaluate(16 % x))
                    self.assertTrue(self.evaluate(x < 12))
                    self.assertTrue(self.evaluate(x <= 12))
                    self.assertFalse(self.evaluate(x > 12))
                    self.assertFalse(self.evaluate(x >= 12))
                    self.assertFalse(self.evaluate(12 < x))
                    self.assertFalse(self.evaluate(12 <= x))
                    self.assertTrue(self.evaluate(12 > x))
                    self.assertTrue(self.evaluate(12 >= x))
                    self.assertAlmostEqual(343,
                                           self.evaluate(pow(x, 3)),
                                           places=4)
                    self.assertAlmostEqual(128,
                                           self.evaluate(pow(2, x)),
                                           places=4)
                    self.assertAlmostEqual(-7, self.evaluate(-x))
                    self.assertAlmostEqual(7, self.evaluate(abs(x)))

                    x = get_var([7, 8, 9], dtypes.float32)
                    x = autocast_variable.create_autocast_variable(x)
                    self.evaluate(x.initializer)
                    self.assertEqual(self.evaluate(x[1]), 8)
                    if tf2.enabled() and context.executing_eagerly():
                        self.assertAllEqual(x == [7., 8., 10.],
                                            [True, True, False])
                        self.assertAllEqual(x != [7., 8., 10.],
                                            [False, False, True])
    def test_assign_stays_in_true_dtype(self, distribute):
        with get_distribute_scope(distribute):
            x = get_var(1., dtypes.float32)
            x = autocast_variable.create_autocast_variable(x)
            self.evaluate(x.initializer)
            # small_val is a value such that 1.0 + small_val == 1.0 in fp16, but not
            # in fp32
            small_val = np.finfo('float16').eps / 2
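            # Concretely: float16 eps is 2**-10 (about 9.8e-4), so small_val is
            # 2**-11. In float16, 1.0 + 2**-11 rounds back to 1.0 (round half to
            # even), while float32 represents 1.0 + 2**-11 exactly.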
            small_tensor = constant_op.constant(small_val,
                                                dtype=dtypes.float32)
            with ops.get_default_graph()._enable_auto_casting_variables(
                    dtypes.float16):
                # The variable should be increased, even though the result appears
                # to be the same value when read as float16.
                self.assertEqual(1. + small_val,
                                 self.evaluate(x.assign(1. + small_tensor)))
                self.assertEqual(1., self.evaluate(x.value()))
            self.assertEqual(1. + small_val, self.evaluate(x.value()))

            self.evaluate(x.assign(1.))
            with ops.get_default_graph()._enable_auto_casting_variables(
                    dtypes.float16):
                self.assertEqual(1. + small_val,
                                 self.evaluate(x.assign_add(small_tensor)))
                self.assertEqual(1., self.evaluate(x.value()))
            self.assertEqual(1. + small_val, self.evaluate(x.value()))
    def test_repr(self):
        # We do not test with DistributionStrategy because we do not want to rely on
        # the exact __repr__ output of a DistributedVariable.
        x = get_var(1., dtypes.float32, name='x')
        x = autocast_variable.create_autocast_variable(x)
        if context.executing_eagerly():
            self.assertStartsWith(
                repr(x),
                "<AutoCastVariable 'x:0' shape=() dtype=float32 true_dtype=float32, "
                "numpy=")
            with ops.get_default_graph()._enable_auto_casting_variables(
                    dtypes.float16):
                self.assertStartsWith(
                    repr(x), "<AutoCastVariable 'x:0' shape=() dtype=float16 "
                    "true_dtype=float32, numpy=")
        else:
            self.assertEqual(
                repr(x),
                "<AutoCastVariable 'x:0' shape=() dtype=float32 true_dtype=float32>"
            )
            with ops.get_default_graph()._enable_auto_casting_variables(
                    dtypes.float16):
                self.assertEqual(
                    repr(x),
                    "<AutoCastVariable 'x:0' shape=() dtype=float16 true_dtype=float32>"
                )
Example #5
    def test_multiple_source_types(self, loss_scale, strategy_fn,
                                   use_tf_function):
        loss_scale = loss_scale(32)
        strategy = strategy_fn()
        with strategy.scope():
            x1 = variables.Variable(1.0)  # Distributed variable
            x2 = variables.Variable([1.0, 2.0])  # Distributed non-scalar variable
            # Distributed AutoCastVariable
            x3 = autocast_variable.create_autocast_variable(
                variables.Variable(2.0))
        x4 = variables.Variable(2.0)  # Non-distributed variable
        x5 = constant_op.constant(2.0)  # Tensor

        def run_fn():
            with lsgt.LossScaleGradientTape(loss_scale) as g:
                g.watch(x5)
                y = x1 * x2 * x3 * x4 * x5
            return g.gradient(y, [x1, x2, x3, x4, x5])

        x1g, x2g, x3g, x4g, x5g = self._run_with_strategy(
            run_fn, strategy, use_tf_function)
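        # Expected values: with x1 = 1, x2 = [1, 2], x3 = x4 = x5 = 2 and
        # y = x1 * x2 * x3 * x4 * x5, the reduced gradients are
        # dy/dx1 = (1 + 2) * 8 = 24, dy/dx2 = 8 per element, and
        # dy/dx3 = dy/dx4 = dy/dx5 = (1 + 2) * 4 = 12.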
        self.assertEqual(loss_scale(), 32)
        for dy_dx1 in x1g:
            self.assertEqual(dy_dx1, 24.0)
        for dy_dx2 in x2g:
            self.assertAllEqual(dy_dx2, [8.0, 8.0])
        for dy_dx3 in x3g:
            self.assertEqual(dy_dx3, 12.0)
        for dy_dx4 in x4g:
            self.assertEqual(dy_dx4, 12.0)
        for dy_dx5 in x5g:
            self.assertEqual(dy_dx5, 12.0)
    def test_read(self, distribute):
        with get_distribute_scope(distribute):
            x = get_var(1., dtypes.float32)
            x = autocast_variable.create_autocast_variable(x)
            self.evaluate(x.initializer)

            # outside of auto cast scope.
            self.assertEqual(x.dtype, dtypes.float32)
            self.assertEqual(x.value().dtype, dtypes.float32)
            self.assertEqual(x.read_value().dtype, dtypes.float32)
            self.assertEqual(array_ops.identity(x).dtype, dtypes.float32)

            # within auto cast scope of different dtype
            with ops.get_default_graph()._enable_auto_casting_variables(
                    dtypes.float16):
                self.assertEqual(x.dtype, dtypes.float16)
                self.assertEqual(x.value().dtype, dtypes.float16)
                self.assertEqual(x.read_value().dtype, dtypes.float16)
                self.assertEqual(array_ops.identity(x).dtype, dtypes.float16)

            # within auto cast scope of same dtype
            with ops.get_default_graph()._enable_auto_casting_variables(
                    dtypes.float32):
                self.assertEqual(x.dtype, dtypes.float32)
                self.assertEqual(x.value().dtype, dtypes.float32)
                self.assertEqual(x.read_value().dtype, dtypes.float32)
                self.assertEqual(array_ops.identity(x).dtype, dtypes.float32)
    def test_repr_distributed(self):
        with mirrored_strategy.MirroredStrategy(['/cpu:1', '/cpu:2']).scope():
            x = get_var(1., dtypes.float32)
            x = autocast_variable.create_autocast_variable(x)
            self.assertRegexpMatches(
                repr(x).replace('\n', ' '),
                '<AutoCastDistributedVariable dtype=float32 true_dtype=float32 '
                'inner_variable=MirroredVariable.*>')
    def test_repr_distributed(self):
        with get_distribute_scope(distribute=True):
            x = get_var(1., dtypes.float32)
            x = autocast_variable.create_autocast_variable(x)
            self.assertRegexpMatches(
                repr(x).replace('\n', ' '),
                '<AutoCastDistributedVariable dtype=float32 true_dtype=float32 '
                'inner_variable=MirroredVariable.*>')
    def test_assign(self, distribute):
        with get_distribute_scope(distribute):
            x = get_var(0., dtypes.float32)
            x = autocast_variable.create_autocast_variable(x)
            self.evaluate(x.initializer)

            # outside of auto cast scope.
            v1 = constant_op.constant(3.14, dtype=dtypes.float32)
            v2 = constant_op.constant(3.14, dtype=dtypes.float16)

            def run_and_check():
                # Assign float32 values
                self.assertAllClose(3.14, self.evaluate(x.assign(v1)))
                self.assertAllClose(3.14 * 2, self.evaluate(x.assign_add(v1)))
                self.assertAllClose(3.14, self.evaluate(x.assign_sub(v1)))

                # Attempt to assign float16 values
                with self.assertRaisesRegexp(
                        ValueError,
                        'conversion requested dtype float32 for Tensor with dtype float16'
                ):
                    self.evaluate(x.assign(v2))
                with self.assertRaisesRegexp(
                        ValueError,
                        'conversion requested dtype float32 for Tensor with dtype float16'
                ):
                    self.evaluate(x.assign_add(v2))
                with self.assertRaisesRegexp(
                        ValueError,
                        'conversion requested dtype float32 for Tensor with dtype float16'
                ):
                    self.evaluate(x.assign_sub(v2))

                # Assign Python floats
                self.assertAllClose(0., self.evaluate(x.assign(0.)))
                self.assertAllClose(3.14, self.evaluate(x.assign(3.14)))
                self.assertAllClose(3.14 * 2,
                                    self.evaluate(x.assign_add(3.14)))
                self.assertAllClose(3.14, self.evaluate(x.assign_sub(3.14)))

                # Use the tf.assign functions instead of the var.assign methods.
                self.assertAllClose(0., self.evaluate(state_ops.assign(x, 0.)))
                self.assertAllClose(3.14,
                                    self.evaluate(state_ops.assign(x, 3.14)))
                self.assertAllClose(
                    3.14 * 2, self.evaluate(state_ops.assign_add(x, 3.14)))
                self.assertAllClose(
                    3.14, self.evaluate(state_ops.assign_sub(x, 3.14)))

            run_and_check()
            # reset x
            self.evaluate(x.assign(0.))
            # within auto cast scope.
            with ops.get_default_graph()._enable_auto_casting_variables(
                    dtypes.float16):
                # assign still expects a float32 value, even inside a float16 scope
                run_and_check()
    def test_assign_op(self, distribution):
        with distribution.scope():
            x = get_var(0., dtypes.float32)
            x = autocast_variable.create_autocast_variable(x)

            @def_function.function
            def func():
                self.assertIsNotNone(x.assign(1.0).op)
                self.assertIsNotNone(x.assign_add(1.0).op)
                self.assertIsNotNone(x.assign_sub(1.0).op)

            func()
    def test_checkpoint(self, distribute):
        with get_distribute_scope(distribute):
            x = get_var(1., dtypes.float32)
            x = autocast_variable.create_autocast_variable(x)
        self.evaluate(x.initializer)
        self.evaluate(x.assign(123.))

        checkpoint = trackable_utils.Checkpoint(x=x)
        prefix = os.path.join(self.get_temp_dir(), 'ckpt')
        save_path = checkpoint.save(prefix)
        self.evaluate(x.assign(234.))
        checkpoint.restore(save_path).assert_consumed().run_restore_ops()
        self.assertEqual(self.evaluate(x), 123.)
    def test_sparse_reads(self):
        x = get_var([1., 2], dtypes.float32)
        # DistributedVariables do not support sparse_read or gather_nd, so we pass
        # distribute=False
        x = autocast_variable.create_autocast_variable(x)
        self.evaluate(x.initializer)

        self.assertEqual(x.sparse_read([0]).dtype, dtypes.float32)
        self.assertEqual(x.gather_nd([0]).dtype, dtypes.float32)

        with autocast_variable.enable_auto_cast_variables(dtypes.float16):
            self.assertEqual(x.sparse_read([0]).dtype, dtypes.float16)
            self.assertEqual(x.gather_nd([0]).dtype, dtypes.float16)
    def test_read_nested_scopes(self, distribution):
        with distribution.scope():
            x = get_var(1., dtypes.float32)
            x = autocast_variable.create_autocast_variable(x)
            self.evaluate(x.initializer)

            with autocast_variable.enable_auto_cast_variables(dtypes.float16):
                self.assertEqual(x.read_value().dtype, dtypes.float16)

                with autocast_variable.enable_auto_cast_variables(
                        dtypes.float32):
                    self.assertEqual(x.read_value().dtype, dtypes.float32)

                self.assertEqual(x.read_value().dtype, dtypes.float16)
Example #14
  def test_dtype_is_not_string(self, distribute):
    with get_distribute_scope(distribute):
      x = get_var(1., dtypes.float32)
      x = autocast_variable.create_autocast_variable(x)
      self.assertEqual(x.dtype, dtypes.float32)
      self.assertIsInstance(x.dtype, dtypes.DType)
      self.assertEqual(x.true_dtype, dtypes.float32)
      self.assertIsInstance(x.true_dtype, dtypes.DType)

      with ops.get_default_graph()._enable_auto_casting_variables('float16'):
        self.assertEqual(x.dtype, dtypes.float16)
        self.assertIsInstance(x.dtype, dtypes.DType)
        self.assertEqual(x.true_dtype, dtypes.float32)
        self.assertIsInstance(x.true_dtype, dtypes.DType)
    def test_dtype_is_not_string(self, distribution):
        with distribution.scope():
            x = get_var(1., dtypes.float32)
            x = autocast_variable.create_autocast_variable(x)
            self.assertEqual(x.dtype, dtypes.float32)
            self.assertIsInstance(x.dtype, dtypes.DType)
            self.assertEqual(x.true_dtype, dtypes.float32)
            self.assertIsInstance(x.true_dtype, dtypes.DType)

            dtype = dtypes.float16
            with autocast_variable.enable_auto_cast_variables(dtype):
                self.assertEqual(x.dtype, dtypes.float32)
                self.assertIsInstance(x.dtype, dtypes.DType)
                self.assertEqual(x.true_dtype, dtypes.float32)
                self.assertIsInstance(x.true_dtype, dtypes.DType)
    def test_assign_tf_function(self, distribution):
        if not context.executing_eagerly():
            self.skipTest('Test is not compatible with graph mode')

        with distribution.scope():
            x = get_var(0., dtypes.float32)
            x = autocast_variable.create_autocast_variable(x)

            @def_function.function
            def run_assign():
                return x.assign(1.).assign_add(3.).assign_add(3.).assign_sub(2.)

            with autocast_variable.enable_auto_cast_variables(dtypes.float16):
                self.assertAllClose(5., self.evaluate(run_assign()))
    def test_repr_distributed(self):
        strategy = mirrored_strategy.MirroredStrategy(['/cpu:1', '/cpu:2'])
        with strategy.scope():
            x = get_var(1., dtypes.float32)
            x = autocast_variable.create_autocast_variable(x)
            use_policy = getattr(strategy.extended, '_use_policy', False)
            if use_policy:
                self.assertRegex(
                    repr(x).replace('\n', ' '),
                    '<AutoCastDistributedVariable dtype=float32 true_dtype=float32 '
                    'inner_variable=DistributedVariable.*>')
            else:
                self.assertRegex(
                    repr(x).replace('\n', ' '),
                    '<AutoCastDistributedVariable dtype=float32 true_dtype=float32 '
                    'inner_variable=MirroredVariable.*>')
    def test_tf_function_control_dependencies(self, distribution):
        if not context.executing_eagerly():
            self.skipTest('Test is not compatible with graph mode')

        with distribution.scope():
            x = get_var(0., dtypes.float32)
            x = autocast_variable.create_autocast_variable(x)

            @def_function.function
            def func():
                update = x.assign_add(1.)
                with ops.control_dependencies([update]):
                    x.assign_add(1.)

            func()
            self.assertAllClose(2., self.evaluate(x))
    def test_optimizer(self, optimizer_class):
        x = get_var(1., dtypes.float32)
        x = autocast_variable.create_autocast_variable(x)
        opt = optimizer_class(1.)

        @def_function.function
        def f():
            opt.minimize(lambda: x + 1., var_list=[x])

        if context.executing_eagerly():
            f()
        else:
            op = f()  # pylint: disable=assignment-from-no-return
            self.evaluate(variables.global_variables_initializer())
            self.evaluate(op)
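        # With a learning rate of 1 and d(x + 1)/dx == 1, a single minimize step
        # moves x from 1 to 0.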
        self.assertEqual(self.evaluate(x), 0)
    def test_read_nested_scopes(self, distribute):
        with get_distribute_scope(distribute):
            x = get_var(1., dtypes.float32)
            x = autocast_variable.create_autocast_variable(x)
            self.evaluate(x.initializer)

            with ops.get_default_graph()._enable_auto_casting_variables(
                    dtypes.float16):
                self.assertEqual(x.dtype, dtypes.float16)
                self.assertEqual(x.read_value().dtype, dtypes.float16)

                with ops.get_default_graph()._enable_auto_casting_variables(
                        dtypes.float32):
                    self.assertEqual(x.dtype, dtypes.float32)
                    self.assertEqual(x.read_value().dtype, dtypes.float32)

                self.assertEqual(x.dtype, dtypes.float16)
                self.assertEqual(x.read_value().dtype, dtypes.float16)
    def test_thread_local_autocast_dtype(self):
        x = get_var(1., dtypes.float32)
        x = autocast_variable.create_autocast_variable(x)
        self.evaluate(x.initializer)

        with autocast_variable.enable_auto_cast_variables(dtypes.float16):
            self.assertEqual(array_ops.identity(x).dtype, dtypes.float16)

            # New threads should not see the modified value of the autocast dtype.
            var_dtype = None

            def f():
                nonlocal var_dtype
                var_dtype = x._cast_dtype

            thread = threading.Thread(target=f)
            thread.start()
            thread.join()
            self.assertEqual(var_dtype, dtypes.float32)
    def test_method_delegations(self, distribution):
        # Test that AutoCastVariable correctly delegates Variable methods to the
        # underlying variable.
        with self.test_session(), distribution.scope():
            for read_dtype in (dtypes.float32, dtypes.float16):
                if ds_context.has_strategy():
                    # MirroredVariable.assign will (incorrectly) return a Mirrored value
                    # instead of a MirroredVariable. So we cannot properly wrap it in an
                    # AutoCastVariable.
                    evaluate = self.evaluate
                else:

                    def evaluate(var):
                        self.assertIsInstance(
                            var, autocast_variable.AutoCastVariable)
                        self.assertEqual(
                            array_ops.identity(var).dtype, read_dtype)  # pylint: disable=cell-var-from-loop
                        return self.evaluate(var)

                x = get_var(7., dtypes.float32)
                x = autocast_variable.create_autocast_variable(x)
                with autocast_variable.enable_auto_cast_variables(read_dtype):
                    self.evaluate(x.initializer)
                    self.assertEqual(self.evaluate(x.value()), 7)
                    self.assertEqual(self.evaluate(x.read_value()), 7)
                    self.assertTrue(x.trainable)
                    self.assertEqual(x.synchronization,
                                     x._variable.synchronization)
                    self.assertEqual(x.aggregation, x._variable.aggregation)
                    self.assertEqual(self.evaluate(x.initialized_value()), 7)
                    if not context.executing_eagerly():
                        if not ds_context.has_strategy():
                            # These functions are not supported for DistributedVariables
                            x.load(9)
                            self.assertEqual(x.eval(), 9)
                        self.assertEqual(self.evaluate(x.initial_value), 7)
                        self.assertEqual(x.op, x._variable.op)
                        self.assertEqual(x.graph, x._variable.graph)
                    if not ds_context.has_strategy():
                        # These attributes are not supported for DistributedVariables
                        self.assertIsNone(x.constraint)
                        self.assertEqual(x.initializer,
                                         x._variable.initializer)
                    self.assertEqual(evaluate(x.assign(8)), 8)
                    self.assertEqual(evaluate(x.assign_add(2)), 10)
                    self.assertEqual(evaluate(x.assign_sub(3)), 7)
                    self.assertEqual(x.name, x._variable.name)
                    self.assertEqual(x.device, x._variable.device)
                    self.assertEqual(x.shape, ())
                    self.assertEqual(x.get_shape(), ())

                if not ds_context.has_strategy():
                    # Test scatter_* methods. These are not supported for
                    # DistributedVariables
                    x = get_var([7, 8], dtypes.float32)
                    x = autocast_variable.create_autocast_variable(x)
                    with autocast_variable.enable_auto_cast_variables(
                            read_dtype):
                        self.evaluate(x.initializer)
                        self.assertAllEqual(self.evaluate(x.value()), [7, 8])

                        def slices(val, index):
                            return indexed_slices.IndexedSlices(
                                values=constant_op.constant(
                                    val, dtype=dtypes.float32),
                                indices=constant_op.constant(
                                    index, dtype=dtypes.int32),
                                dense_shape=constant_op.constant(
                                    [2], dtype=dtypes.int32))

                        self.assertAllEqual(
                            evaluate(x.scatter_sub(slices(1., 0))), [6, 8])
                        self.assertAllEqual(
                            evaluate(x.scatter_add(slices(1., 0))), [7, 8])
                        self.assertAllEqual(
                            evaluate(x.scatter_max(slices(9., 1))), [7, 9])
                        self.assertAllEqual(
                            evaluate(x.scatter_min(slices(8., 1))), [7, 8])
                        self.assertAllEqual(
                            evaluate(x.scatter_mul(slices(2., 1))), [7, 16])
                        self.assertAllEqual(
                            evaluate(x.scatter_div(slices(2., 1))), [7, 8])
                        self.assertAllEqual(
                            evaluate(x.scatter_update(slices(4., 1))), [7, 4])
                        self.assertAllEqual(
                            evaluate(x.scatter_nd_sub([[0], [1]], [1., 2.])),
                            [6, 2])
                        self.assertAllEqual(
                            evaluate(x.scatter_nd_add([[0], [1]], [1., 2.])),
                            [7, 4])
                        self.assertAllEqual(
                            evaluate(x.scatter_nd_update([[0], [1]],
                                                         [1., 2.])), [1, 2])
    def test_method_delegations(self, distribute):
        # Test that AutoCastVariable correctly delegates Variable methods to the
        # underlying variable.
        with get_distribute_scope(distribute):
            evaluate = self.evaluate
            for read_dtype in (dtypes.float32, dtypes.float16):
                x = get_var(7., dtypes.float32)
                x = autocast_variable.create_autocast_variable(x)
                with ops.get_default_graph()._enable_auto_casting_variables(
                        read_dtype):
                    evaluate(x.initializer)
                    self.assertEqual(evaluate(x.value()), 7)
                    self.assertEqual(evaluate(x.read_value()), 7)
                    self.assertTrue(x.trainable)
                    self.assertEqual(x.synchronization,
                                     x._variable.synchronization)
                    self.assertEqual(x.aggregation, x._variable.aggregation)
                    self.assertEqual(evaluate(x.initialized_value()), 7)
                    if not context.executing_eagerly():
                        if not distribute:
                            # These functions are not supported for DistributedVariables
                            x.load(9)
                            self.assertEqual(x.eval(), 9)
                        self.assertEqual(evaluate(x.initial_value), 7)
                        self.assertEqual(x.op, x._variable.op)
                        self.assertEqual(x.graph, x._variable.graph)
                    if not distribute:
                        # These attributes are not supported for DistributedVariables
                        self.assertIsNone(x.constraint)
                        self.assertEqual(x.initializer,
                                         x._variable.initializer)
                    self.assertEqual(evaluate(x.assign(8)), 8)
                    self.assertEqual(evaluate(x.assign_add(2)), 10)
                    self.assertEqual(evaluate(x.assign_sub(3)), 7)
                    self.assertEqual(x.name, x._variable.name)
                    self.assertEqual(x.device, x._variable.device)
                    self.assertEqual(x.shape, ())
                    self.assertEqual(x.get_shape(), ())

                if not distribute:
                    # Test scatter_* methods. These are not supported for
                    # DistributedVariables
                    x = get_var([7, 8], dtypes.float32)
                    x = autocast_variable.create_autocast_variable(x)
                    with ops.get_default_graph()._enable_auto_casting_variables(
                            read_dtype):
                        evaluate(x.initializer)
                        self.assertAllEqual(evaluate(x.value()), [7, 8])

                        def slices(val, index):
                            return indexed_slices.IndexedSlices(
                                values=constant_op.constant(
                                    val, dtype=dtypes.float32),
                                indices=constant_op.constant(
                                    index, dtype=dtypes.int32),
                                dense_shape=constant_op.constant(
                                    [2], dtype=dtypes.int32))

                        self.assertAllEqual(
                            evaluate(x.scatter_sub(slices(1., 0))), [6, 8])
                        self.assertAllEqual(
                            evaluate(x.scatter_add(slices(1., 0))), [7, 8])
                        self.assertAllEqual(
                            evaluate(x.scatter_max(slices(9., 1))), [7, 9])
                        self.assertAllEqual(
                            evaluate(x.scatter_min(slices(8., 1))), [7, 8])
                        self.assertAllEqual(
                            evaluate(x.scatter_mul(slices(2., 1))), [7, 16])
                        self.assertAllEqual(
                            evaluate(x.scatter_div(slices(2., 1))), [7, 8])
                        self.assertAllEqual(
                            evaluate(x.scatter_update(slices(4., 1))), [7, 4])
                        self.assertAllEqual(
                            evaluate(x.scatter_nd_sub([[0], [1]], [1., 2.])),
                            [6, 2])
                        self.assertAllEqual(
                            evaluate(x.scatter_nd_add([[0], [1]], [1., 2.])),
                            [7, 4])
                        self.assertAllEqual(
                            evaluate(x.scatter_nd_update([[0], [1]],
                                                         [1., 2.])), [1, 2])
    def test_assign(self, distribution):
        with distribution.scope():
            x = get_var(0., dtypes.float32)
            x = autocast_variable.create_autocast_variable(x)
            self.evaluate(x.initializer)

            # outside of auto cast scope.
            v1 = constant_op.constant(3., dtype=dtypes.float32)
            v2 = constant_op.constant(3., dtype=dtypes.float16)

            def run_and_check():
                # Assign float32 values
                self.assertAllClose(3., self.evaluate(x.assign(v1)))
                self.assertAllClose(3. * 2, self.evaluate(x.assign_add(v1)))
                self.assertAllClose(3., self.evaluate(x.assign_sub(v1)))

                # Attempt to assign float16 values
                with self.assertRaisesRegex(
                        ValueError,
                        'conversion requested dtype float32 for Tensor with dtype float16'
                ):
                    self.evaluate(x.assign(v2))
                with self.assertRaisesRegex(
                        ValueError,
                        'conversion requested dtype float32 for Tensor with dtype float16'
                ):
                    self.evaluate(x.assign_add(v2))
                with self.assertRaisesRegex(
                        ValueError,
                        'conversion requested dtype float32 for Tensor with dtype float16'
                ):
                    self.evaluate(x.assign_sub(v2))

                # Assign Python floats
                self.assertAllClose(0., self.evaluate(x.assign(0.)))
                self.assertAllClose(3., self.evaluate(x.assign(3.)))
                self.assertAllClose(3. * 2, self.evaluate(x.assign_add(3.)))
                self.assertAllClose(3., self.evaluate(x.assign_sub(3.)))

                # Assign multiple times
                # This currently doesn't work in graph mode if a strategy is used
                if not ds_context.has_strategy() or context.executing_eagerly():
                    assign = x.assign(1.)
                    self.assertAllClose(1., self.evaluate(assign))
                    self.assertAllClose(0., self.evaluate(assign.assign(0.)))
                    assign_add = x.assign_add(3.)
                    self.assertAllClose(3., self.evaluate(assign_add))
                    self.assertAllClose(
                        3. * 3, self.evaluate(x.assign_add(3.).assign_add(3.)))
                    self.assertAllClose(3. * 3, x)
                    assign_sub = x.assign_sub(3.)
                    self.assertAllClose(3. * 2, self.evaluate(assign_sub))
                    self.assertAllClose(
                        0., self.evaluate(x.assign_sub(3.).assign_sub(3.)))

                # Assign with read_value=False
                self.assertIsNone(self.evaluate(x.assign(1.,
                                                         read_value=False)))
                self.assertAllClose(1., self.evaluate(x))
                self.assertIsNone(
                    self.evaluate(x.assign_add(2., read_value=False)))
                self.assertAllClose(3., self.evaluate(x))
                self.assertIsNone(
                    self.evaluate(x.assign_sub(3., read_value=False)))
                self.assertAllClose(0., self.evaluate(x))

                # Use the tf.assign functions instead of the var.assign methods.
                self.assertAllClose(0., self.evaluate(state_ops.assign(x, 0.)))
                self.assertAllClose(3., self.evaluate(state_ops.assign(x, 3.)))
                self.assertAllClose(3. * 2,
                                    self.evaluate(state_ops.assign_add(x, 3.)))
                self.assertAllClose(3.,
                                    self.evaluate(state_ops.assign_sub(x, 3.)))

            run_and_check()
            # reset x
            self.evaluate(x.assign(0.))
            # within auto cast scope.
            with autocast_variable.enable_auto_cast_variables(dtypes.float16):
                # assign still expects a float32 value, even inside a float16 scope
                run_and_check()
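
A minimal eager-mode sketch (assumed setup, not taken from the test file) tying the pieces together: reads inside enable_auto_cast_variables are cast to the requested dtype, while assignments still update the underlying float32 storage.

from tensorflow.python.framework import dtypes
from tensorflow.python.ops import variables
# The import path for autocast_variable differs across TF versions; this assumes
# the same module exercised by the examples above is importable here.
from tensorflow.python.keras.mixed_precision import autocast_variable

v = autocast_variable.create_autocast_variable(
    variables.Variable(1., dtype=dtypes.float32))
with autocast_variable.enable_auto_cast_variables(dtypes.float16):
    y = v * 2.  # y is computed and returned in float16
v.assign_add(0.5)  # the update happens in the variable's true dtype (float32)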