Example #1
class TraceModelCallTest(keras_parameterized.TestCase):
    def _assert_all_close(self, expected, actual):
        if not context.executing_eagerly():
            with self.cached_session() as sess:
                K._initialize_variables(sess)
                self.assertAllClose(expected, actual)
        else:
            self.assertAllClose(expected, actual)

    @keras_parameterized.run_with_all_model_types
    @keras_parameterized.run_all_keras_modes
    def test_trace_model_outputs(self):
        input_dim = 5 if testing_utils.get_model_type() == 'functional' else None
        model = testing_utils.get_small_mlp(10, 3, input_dim)
        inputs = array_ops.ones((8, 5))

        if input_dim is None:
            with self.assertRaisesRegex(ValueError,
                                        'input shapes have not been set'):
                saving_utils.trace_model_call(model)
            model._set_inputs(inputs)

        fn = saving_utils.trace_model_call(model)
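        # The traced function returns a dict keyed by the model's output
        # names, falling back to 'output_1', 'output_2', ... when no names
        # are set (as the expected_outputs construction below reflects).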
        signature_outputs = fn(inputs)
        if model.output_names:
            expected_outputs = {model.output_names[0]: model(inputs)}
        else:
            expected_outputs = {'output_1': model(inputs)}

        self._assert_all_close(expected_outputs, signature_outputs)

    @keras_parameterized.run_with_all_model_types
    @keras_parameterized.run_all_keras_modes
    def test_trace_model_outputs_after_fitting(self):
        input_dim = 5 if testing_utils.get_model_type() == 'functional' else None
        model = testing_utils.get_small_mlp(10, 3, input_dim)
        model.compile(optimizer='sgd',
                      loss='mse',
                      run_eagerly=testing_utils.should_run_eagerly())
        model.fit(x=np.random.random((8, 5)).astype(np.float32),
                  y=np.random.random((8, 3)).astype(np.float32),
                  epochs=2)

        inputs = array_ops.ones((8, 5))

        fn = saving_utils.trace_model_call(model)
        signature_outputs = fn(inputs)
        if model.output_names:
            expected_outputs = {model.output_names[0]: model(inputs)}
        else:
            expected_outputs = {'output_1': model(inputs)}

        self._assert_all_close(expected_outputs, signature_outputs)

    @keras_parameterized.run_with_all_model_types(exclude_models='sequential')
    @keras_parameterized.run_all_keras_modes
    def test_trace_multi_io_model_outputs(self):
        input_dim = 5
        num_classes = 3
        num_classes_b = 4
        input_a = keras.layers.Input(shape=(input_dim, ), name='input_a')
        input_b = keras.layers.Input(shape=(input_dim, ), name='input_b')

        dense = keras.layers.Dense(num_classes, name='dense')
        dense2 = keras.layers.Dense(num_classes_b, name='dense2')
        dropout = keras.layers.Dropout(0.5, name='dropout')
        branch_a = [input_a, dense]
        branch_b = [input_b, dense, dense2, dropout]

        model = testing_utils.get_multi_io_model(branch_a, branch_b)

        input_a_np = np.random.random((10, input_dim)).astype(np.float32)
        input_b_np = np.random.random((10, input_dim)).astype(np.float32)

        if testing_utils.get_model_type() == 'subclass':
            with self.assertRaisesRegex(ValueError,
                                        'input shapes have not been set'):
                saving_utils.trace_model_call(model)

        model.compile(optimizer='sgd',
                      loss='mse',
                      run_eagerly=testing_utils.should_run_eagerly())
        model.fit(x=[
            np.random.random((8, input_dim)).astype(np.float32),
            np.random.random((8, input_dim)).astype(np.float32)
        ],
                  y=[
                      np.random.random((8, num_classes)).astype(np.float32),
                      np.random.random((8, num_classes_b)).astype(np.float32)
                  ],
                  epochs=2)

        fn = saving_utils.trace_model_call(model)
        signature_outputs = fn([input_a_np, input_b_np])
        outputs = model([input_a_np, input_b_np])
        if model.output_names:
            expected_outputs = {
                model.output_names[0]: outputs[0],
                model.output_names[1]: outputs[1]
            }
        else:
            expected_outputs = {'output_1': outputs[0], 'output_2': outputs[1]}
        self._assert_all_close(expected_outputs, signature_outputs)

    @combinations.generate(combinations.combine(mode=['graph', 'eager']))
    def test_trace_features_layer(self):
        columns = [feature_column_lib.numeric_column('x')]
        model = sequential.Sequential([dense_features.DenseFeatures(columns)])
        model_input = {'x': constant_op.constant([[1.]])}
        model.predict(model_input, steps=1)
        fn = saving_utils.trace_model_call(model)
        self.assertAllClose({'output_1': [[1.]]}, fn({'x': [[1.]]}))

        columns = [
            feature_column_lib.numeric_column('x'),
            feature_column_lib.numeric_column('y')
        ]
        model = sequential.Sequential([dense_features.DenseFeatures(columns)])
        model_input = {
            'x': constant_op.constant([[1.]]),
            'y': constant_op.constant([[2.]])
        }
        model.predict(model_input, steps=1)
        fn = saving_utils.trace_model_call(model)
        self.assertAllClose({'output_1': [[1., 2.]]},
                            fn({
                                'x': [[1.]],
                                'y': [[2.]]
                            }))

    @combinations.generate(combinations.combine(mode=['graph', 'eager']))
    def test_specify_input_signature(self):
        model = testing_utils.get_small_sequential_mlp(10, 3, None)
        inputs = array_ops.ones((8, 5))

        with self.assertRaisesRegex(ValueError,
                                    'input shapes have not been set'):
            saving_utils.trace_model_call(model)

        fn = saving_utils.trace_model_call(
            model,
            [tensor_spec.TensorSpec(shape=[None, 5], dtype=dtypes.float32)])
        signature_outputs = fn(inputs)
        if model.output_names:
            expected_outputs = {model.output_names[0]: model(inputs)}
        else:
            expected_outputs = {'output_1': model(inputs)}
        self._assert_all_close(expected_outputs, signature_outputs)

    @combinations.generate(combinations.combine(mode=['graph', 'eager']))
    def test_subclassed_model_with_input_signature(self):
        class Model(keras.Model):
            def __init__(self):
                super(Model, self).__init__()
                self.dense = keras.layers.Dense(3, name='dense')

            @def_function.function(
                input_signature=[[
                    tensor_spec.TensorSpec([None, 5], dtypes.float32),
                    tensor_spec.TensorSpec([None], dtypes.float32)
                ]], )
            def call(self, inputs, *args):
                x, y = inputs
                return self.dense(x) + y

        model = Model()
        fn = saving_utils.trace_model_call(model)
        x = array_ops.ones((8, 5), dtype=dtypes.float32)
        y = array_ops.ones((3, ), dtype=dtypes.float32)
        expected_outputs = {'output_1': model([x, y])}
        signature_outputs = fn([x, y])
        self._assert_all_close(expected_outputs, signature_outputs)

    @keras_parameterized.run_with_all_model_types
    @keras_parameterized.run_all_keras_modes
    def test_model_with_fixed_input_dim(self):
        """Ensure that the batch_dim is removed when saving.

    When serving or retraining, it is important to reset the batch dim.
    This can be an issue inside of tf.function. See b/132783590 for context.
    """
        model = testing_utils.get_small_mlp(10, 3, 5)

        loss_object = keras.losses.MeanSquaredError()
        optimizer = gradient_descent.SGD()

        @def_function.function
        def train_step(data, labels):
            with backprop.GradientTape() as tape:
                predictions = model(data)
                loss = loss_object(labels, predictions)
            gradients = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(gradients,
                                          model.trainable_variables))

        x = np.random.random((8, 5))
        y = np.random.random((8, 3))

        train_step(x, y)

        fn = saving_utils.trace_model_call(model)
        self.assertEqual(fn.input_signature[0].shape.as_list(),
                         tensor_shape.TensorShape([None, 5]).as_list())
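Example #2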
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras import combinations
from tensorflow.python.keras.utils import losses_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.platform import test


@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class RemoveSqueezableTest(test_util.TensorFlowTestCase):
    """Test remove_squeezable_dimensions"""
    def test_ragged_3d_same_shape(self):
        """ shape (2, (sequence={1, 2}), 3)"""
        x = ragged_factory_ops.constant([[[1, 2, 3]], [[4, 5, 6], [7, 8, 9]]])
        rank = x.shape.ndims
        x_p, _ = losses_utils.remove_squeezable_dimensions(x, x)
        self.assertEqual(x_p.shape.ndims, rank)

    def test_ragged_3d_4d_squeezable(self):
        """shapes:

        x: (2, (sequence={1, 2}), 3)
        y: (2, (sequence={1, 2}), 3, 1)
        """
        # Truncated in the source page; a representative completion based on
        # the docstring and the sibling test above.
        x = ragged_factory_ops.constant([[[1, 2, 3]], [[4, 5, 6], [7, 8, 9]]])
        y = ragged_factory_ops.constant([[[[1], [2], [3]]],
                                         [[[4], [5], [6]], [[7], [8], [9]]]])
        _, y_p = losses_utils.remove_squeezable_dimensions(x, y)
        self.assertEqual(y_p.shape.ndims, x.shape.ndims)
Example #3
class RMSpropOptimizerTest(test.TestCase, parameterized.TestCase):
    def _rmsprop_update_numpy(self, var, g, mg, rms, mom, lr, rho, momentum,
                              epsilon, centered):
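        # Numpy reference for one (optionally centered) RMSprop step. Note
        # that the epsilon placement differs between the branches below: with
        # momentum it sits inside the sqrt, without momentum it is added
        # outside, mirroring how the optimizer computes each case.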
        rms_t = rms * rho + (1 - rho) * g * g
        if centered:
            mg_t = mg * rho + (1 - rho) * g
            denom_t = rms_t - mg_t * mg_t
        else:
            mg_t = mg
            denom_t = rms_t
        if momentum > 0.:
            mom_t = momentum * mom + lr * g / (np.sqrt(denom_t + epsilon))
            var_t = var - mom_t
        else:
            mom_t = mom
            var_t = var - lr * g / (np.sqrt(denom_t) + epsilon)
        return var_t, mg_t, rms_t, mom_t

    def _sparse_rmsprop_update_numpy(self, var, gindexs, gvalues, mg, rms, mom,
                                     lr, rho, momentum, epsilon, centered):
        mg_t = copy.deepcopy(mg)
        rms_t = copy.deepcopy(rms)
        mom_t = copy.deepcopy(mom)
        var_t = copy.deepcopy(var)
        for i in range(len(gindexs)):
            gindex = gindexs[i]
            gvalue = gvalues[i]
            rms_t[gindex] = rms[gindex] * rho + (1 - rho) * gvalue * gvalue
            if centered:
                mg_t[gindex] = mg_t[gindex] * rho + (1 - rho) * gvalue
                denom_t = rms_t[gindex] - mg_t[gindex] * mg_t[gindex]
            else:
                denom_t = rms_t[gindex]
            if momentum > 0.:
                mom_t[gindex] = momentum * mom[gindex] + lr * gvalue / np.sqrt(
                    denom_t + epsilon)
                var_t[gindex] = var[gindex] - mom_t[gindex]
            else:
                mom_t[gindex] = mom[gindex]
                var_t[gindex] = var[gindex] - lr * gvalue / (np.sqrt(denom_t) +
                                                             epsilon)
        return var_t, mg_t, rms_t, mom_t

    def testDense(self):
        # TODO(tanzheny, omalleyt): Fix test in eager mode.
        for (dtype, learning_rate, rho, momentum, epsilon,
             centered) in _TESTPARAMS:
            with ops.get_default_graph().as_default(), testing_utils.use_gpu():
                # Initialize variables for numpy implementation.
                var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
                grads0_np = np.array([0.1, 0.2], dtype=dtype.as_numpy_dtype)
                var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
                grads1_np = np.array([0.01, 0.2], dtype=dtype.as_numpy_dtype)

                var0 = variables.Variable(var0_np, dtype=dtype)
                var1 = variables.Variable(var1_np, dtype=dtype)
                grads0 = constant_op.constant(grads0_np, dtype=dtype)
                grads1 = constant_op.constant(grads1_np, dtype=dtype)
                opt = rmsprop.RMSprop(learning_rate=learning_rate,
                                      rho=rho,
                                      momentum=momentum,
                                      epsilon=epsilon,
                                      centered=centered)

                update = opt.apply_gradients(
                    zip([grads0, grads1], [var0, var1]))
                self.evaluate(variables.global_variables_initializer())

                if centered:
                    mg0 = opt.get_slot(var0, "mg")
                    mg1 = opt.get_slot(var1, "mg")
                else:
                    mg0 = None
                    mg1 = None

                if momentum > 0.:
                    mom0 = opt.get_slot(var0, "momentum")
                    mom1 = opt.get_slot(var1, "momentum")
                else:
                    mom0 = None
                    mom1 = None

                rms0 = opt.get_slot(var0, "rms")
                self.assertIsNotNone(rms0)
                rms1 = opt.get_slot(var1, "rms")
                self.assertIsNotNone(rms1)

                mg0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
                mg1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
                rms0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
                rms1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
                mom0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
                mom1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)

                # Fetch params to validate initial values
                self.assertAllClose([1.0, 2.0], self.evaluate(var0))
                self.assertAllClose([3.0, 4.0], self.evaluate(var1))

                # Run 3 steps of RMSprop
                for _ in range(1, 4):
                    self.evaluate(update)

                    var0_np, mg0_np, rms0_np, mom0_np = self._rmsprop_update_numpy(
                        var0_np, grads0_np, mg0_np, rms0_np, mom0_np,
                        learning_rate, rho, momentum, epsilon, centered)
                    var1_np, mg1_np, rms1_np, mom1_np = self._rmsprop_update_numpy(
                        var1_np, grads1_np, mg1_np, rms1_np, mom1_np,
                        learning_rate, rho, momentum, epsilon, centered)

                    # Validate updated params
                    if centered:
                        self.assertAllCloseAccordingToType(
                            mg0_np, self.evaluate(mg0))
                        self.assertAllCloseAccordingToType(
                            mg1_np, self.evaluate(mg1))
                    if momentum > 0.:
                        self.assertAllCloseAccordingToType(
                            mom0_np, self.evaluate(mom0))
                        self.assertAllCloseAccordingToType(
                            mom1_np, self.evaluate(mom1))
                    self.assertAllCloseAccordingToType(rms0_np,
                                                       self.evaluate(rms0))
                    self.assertAllCloseAccordingToType(rms1_np,
                                                       self.evaluate(rms1))
                    self.assertAllCloseAccordingToType(var0_np,
                                                       self.evaluate(var0))
                    self.assertAllCloseAccordingToType(var1_np,
                                                       self.evaluate(var1))

    def testDenseWithLearningRateDecay(self):
        # TODO(tanzheny, omalleyt): Fix test in eager mode.
        with ops.Graph().as_default():
            var0_np = np.array([1.0, 2.0])
            grads0_np = np.array([0.1, 0.2])
            var1_np = np.array([3.0, 4.0])
            grads1_np = np.array([0.01, 0.2])

            var0 = variables.Variable(var0_np)
            var1 = variables.Variable(var1_np)
            grads0 = constant_op.constant(grads0_np)
            grads1 = constant_op.constant(grads1_np)
            learning_rate = 0.01
            rho = 0.9
            momentum = 0.0
            epsilon = 1e-7
            centered = False
            decay = 0.5
            opt = rmsprop.RMSprop(learning_rate=learning_rate,
                                  rho=rho,
                                  momentum=momentum,
                                  epsilon=epsilon,
                                  centered=centered,
                                  decay=decay)

            update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
            self.evaluate(variables.global_variables_initializer())

            rms0 = opt.get_slot(var0, "rms")
            self.assertIsNotNone(rms0)
            rms1 = opt.get_slot(var1, "rms")
            self.assertIsNotNone(rms1)
            if momentum > 0.:
                mom0 = opt.get_slot(var0, "momentum")
                mom1 = opt.get_slot(var1, "momentum")
            else:
                mom0 = None
                mom1 = None

            mg0_np = np.array([0.0, 0.0])
            mg1_np = np.array([0.0, 0.0])
            rms0_np = np.array([0.0, 0.0])
            rms1_np = np.array([0.0, 0.0])
            mom0_np = np.array([0.0, 0.0])
            mom1_np = np.array([0.0, 0.0])

            # Fetch params to validate initial values
            self.assertAllClose([1.0, 2.0], self.evaluate(var0))
            self.assertAllClose([3.0, 4.0], self.evaluate(var1))

            # Run 2 steps of RMSprop
            for t in range(2):
                self.evaluate(update)

                lr = learning_rate / (1 + decay * t)
                var0_np, mg0_np, rms0_np, mom0_np = self._rmsprop_update_numpy(
                    var0_np, grads0_np, mg0_np, rms0_np, mom0_np, lr, rho,
                    momentum, epsilon, centered)
                var1_np, mg1_np, rms1_np, mom1_np = self._rmsprop_update_numpy(
                    var1_np, grads1_np, mg1_np, rms1_np, mom1_np, lr, rho,
                    momentum, epsilon, centered)

                # Validate updated params
                self.assertAllCloseAccordingToType(rms0_np,
                                                   self.evaluate(rms0))
                self.assertAllCloseAccordingToType(rms1_np,
                                                   self.evaluate(rms1))
                if momentum > 0.:
                    self.assertAllCloseAccordingToType(mom0_np,
                                                       self.evaluate(mom0))
                    self.assertAllCloseAccordingToType(mom1_np,
                                                       self.evaluate(mom1))
                self.assertAllCloseAccordingToType(var0_np,
                                                   self.evaluate(var0))
                self.assertAllCloseAccordingToType(var1_np,
                                                   self.evaluate(var1))

    def testDenseWithLearningRateInverseTimeDecay(self):
        # TODO(tanzheny, omalleyt): Fix test in eager mode.
        with ops.Graph().as_default():
            var0_np = np.array([1.0, 2.0])
            grads0_np = np.array([0.1, 0.2])
            var1_np = np.array([3.0, 4.0])
            grads1_np = np.array([0.01, 0.2])

            var0 = variables.Variable(var0_np)
            var1 = variables.Variable(var1_np)
            grads0 = constant_op.constant(grads0_np)
            grads1 = constant_op.constant(grads1_np)
            learning_rate = 0.01
            rho = 0.9
            momentum = 0.0
            epsilon = 1e-7
            centered = False
            decay = 0.5
            lr_schedule = learning_rate_schedule.InverseTimeDecay(
                learning_rate, decay_steps=1.0, decay_rate=decay)
            opt = rmsprop.RMSprop(learning_rate=lr_schedule,
                                  rho=rho,
                                  momentum=momentum,
                                  epsilon=epsilon,
                                  centered=centered)

            update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
            self.evaluate(variables.global_variables_initializer())

            rms0 = opt.get_slot(var0, "rms")
            self.assertIsNotNone(rms0)
            rms1 = opt.get_slot(var1, "rms")
            self.assertIsNotNone(rms1)
            if momentum > 0.:
                mom0 = opt.get_slot(var0, "momentum")
                mom1 = opt.get_slot(var1, "momentum")
            else:
                mom0 = None
                mom1 = None

            mg0_np = np.array([0.0, 0.0])
            mg1_np = np.array([0.0, 0.0])
            rms0_np = np.array([0.0, 0.0])
            rms1_np = np.array([0.0, 0.0])
            mom0_np = np.array([0.0, 0.0])
            mom1_np = np.array([0.0, 0.0])

            # Fetch params to validate initial values
            self.assertAllClose([1.0, 2.0], self.evaluate(var0))
            self.assertAllClose([3.0, 4.0], self.evaluate(var1))

            # Run 2 steps of RMSprop
            for t in range(2):
                self.evaluate(update)

                lr = learning_rate / (1 + decay * t)
                var0_np, mg0_np, rms0_np, mom0_np = self._rmsprop_update_numpy(
                    var0_np, grads0_np, mg0_np, rms0_np, mom0_np, lr, rho,
                    momentum, epsilon, centered)
                var1_np, mg1_np, rms1_np, mom1_np = self._rmsprop_update_numpy(
                    var1_np, grads1_np, mg1_np, rms1_np, mom1_np, lr, rho,
                    momentum, epsilon, centered)

                # Validate updated params
                self.assertAllCloseAccordingToType(rms0_np,
                                                   self.evaluate(rms0))
                self.assertAllCloseAccordingToType(rms1_np,
                                                   self.evaluate(rms1))
                if momentum > 0.:
                    self.assertAllCloseAccordingToType(mom0_np,
                                                       self.evaluate(mom0))
                    self.assertAllCloseAccordingToType(mom1_np,
                                                       self.evaluate(mom1))
                self.assertAllCloseAccordingToType(var0_np,
                                                   self.evaluate(var0))
                self.assertAllCloseAccordingToType(var1_np,
                                                   self.evaluate(var1))

    def testMinimizeSparseResourceVariable(self):
        # TODO(tanzheny, omalleyt): Fix test in eager mode.
        with ops.Graph().as_default():
            for dtype in _DATA_TYPES:
                var0 = variables.Variable([[1.0, 2.0]], dtype=dtype)
                x = constant_op.constant([[4.0], [5.0]], dtype=dtype)

                def loss():
                    pred = math_ops.matmul(
                        embedding_ops.embedding_lookup([var0], [0]), x)  # pylint: disable=cell-var-from-loop
                    return pred * pred

                sgd_op = rmsprop.RMSprop(learning_rate=1.0,
                                         rho=0.0,
                                         momentum=0.0,
                                         epsilon=0.0,
                                         centered=False).minimize(
                                             loss, var_list=[var0])
                self.evaluate(variables.global_variables_initializer())
                # Fetch params to validate initial values
                self.assertAllCloseAccordingToType([[1.0, 2.0]],
                                                   self.evaluate(var0))
                # Run 1 step of RMSprop
                self.evaluate(sgd_op)
                # Validate updated params
                self.assertAllCloseAccordingToType([[0., 1.]],
                                                   self.evaluate(var0),
                                                   atol=0.01)

    def testMinimizeSparseResourceVariableCentered(self):
        # TODO(tanzheny, omalleyt): Fix test in eager mode.
        with ops.Graph().as_default():
            for dtype in _DATA_TYPES:
                if test_util.is_xla_enabled() and dtype.is_complex:
                    self.skipTest("b/143578550")
                var0 = variables.Variable([[1.0, 2.0]], dtype=dtype)
                x = constant_op.constant([[4.0], [5.0]], dtype=dtype)

                def loss():
                    pred = math_ops.matmul(
                        embedding_ops.embedding_lookup([var0], [0]), x)  # pylint: disable=cell-var-from-loop
                    return pred * pred

                # loss = lambda: pred * pred  # pylint: disable=cell-var-from-loop
                sgd_op = rmsprop.RMSprop(learning_rate=1.0,
                                         rho=0.0,
                                         momentum=0.0,
                                         epsilon=1.0,
                                         centered=True).minimize(
                                             loss, var_list=[var0])
                self.evaluate(variables.global_variables_initializer())
                # Fetch params to validate initial values
                self.assertAllCloseAccordingToType([[1.0, 2.0]],
                                                   self.evaluate(var0))
                # Run 1 step of RMSprop
                self.evaluate(sgd_op)
                # Validate updated params
                self.assertAllCloseAccordingToType([[-111, -138]],
                                                   self.evaluate(var0),
                                                   atol=0.01)

    def testSparse(self):
        # TODO(tanzheny, omalleyt): Fix test in eager mode.
        for (dtype, learning_rate, rho, momentum, epsilon,
             centered) in _TESTPARAMS:
            with ops.get_default_graph().as_default(), testing_utils.use_gpu():
                # Initialize variables for numpy implementation.
                var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
                grads0_np = np.array([0.1], dtype=dtype.as_numpy_dtype)
                var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
                grads1_np = np.array([0.01], dtype=dtype.as_numpy_dtype)

                var0 = variables.Variable(var0_np)
                var1 = variables.Variable(var1_np)
                grads0_np_indices = np.array([0], dtype=np.int32)
                grads0 = ops.IndexedSlices(
                    constant_op.constant(grads0_np),
                    constant_op.constant(grads0_np_indices),
                    constant_op.constant([1]))
                grads1_np_indices = np.array([1], dtype=np.int32)
                grads1 = ops.IndexedSlices(
                    constant_op.constant(grads1_np),
                    constant_op.constant(grads1_np_indices),
                    constant_op.constant([1]))
                opt = rmsprop.RMSprop(learning_rate=learning_rate,
                                      rho=rho,
                                      momentum=momentum,
                                      epsilon=epsilon,
                                      centered=centered)
                update = opt.apply_gradients(
                    zip([grads0, grads1], [var0, var1]))
                self.evaluate(variables.global_variables_initializer())

                if centered:
                    mg0 = opt.get_slot(var0, "mg")
                    self.assertEqual(mg0 is not None, centered)
                    mg1 = opt.get_slot(var1, "mg")
                    self.assertEqual(mg1 is not None, centered)
                else:
                    mg0 = None
                    mg1 = None
                rms0 = opt.get_slot(var0, "rms")
                self.assertIsNotNone(rms0)
                rms1 = opt.get_slot(var1, "rms")
                self.assertIsNotNone(rms1)
                if momentum > 0.:
                    mom0 = opt.get_slot(var0, "momentum")
                    mom1 = opt.get_slot(var1, "momentum")
                else:
                    mom0 = None
                    mom1 = None

                mg0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
                mg1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
                rms0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
                rms1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
                mom0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
                mom1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)

                # Fetch params to validate initial values
                self.assertAllClose([1.0, 2.0], self.evaluate(var0))
                self.assertAllClose([3.0, 4.0], self.evaluate(var1))

                # Run 3 steps of RMSprop
                for _ in range(1, 4):
                    self.evaluate(update)

                    var0_np, mg0_np, rms0_np, mom0_np = self._sparse_rmsprop_update_numpy(
                        var0_np, grads0_np_indices, grads0_np, mg0_np, rms0_np,
                        mom0_np, learning_rate, rho, momentum, epsilon,
                        centered)
                    var1_np, mg1_np, rms1_np, mom1_np = self._sparse_rmsprop_update_numpy(
                        var1_np, grads1_np_indices, grads1_np, mg1_np, rms1_np,
                        mom1_np, learning_rate, rho, momentum, epsilon,
                        centered)

                    # Validate updated params
                    if centered:
                        self.assertAllCloseAccordingToType(
                            mg0_np, self.evaluate(mg0))
                        self.assertAllCloseAccordingToType(
                            mg1_np, self.evaluate(mg1))
                    self.assertAllCloseAccordingToType(rms0_np,
                                                       self.evaluate(rms0))
                    self.assertAllCloseAccordingToType(rms1_np,
                                                       self.evaluate(rms1))
                    if momentum > 0.:
                        self.assertAllCloseAccordingToType(
                            mom0_np, self.evaluate(mom0))
                        self.assertAllCloseAccordingToType(
                            mom1_np, self.evaluate(mom1))
                    self.assertAllCloseAccordingToType(var0_np,
                                                       self.evaluate(var0))
                    self.assertAllCloseAccordingToType(var1_np,
                                                       self.evaluate(var1))

    @combinations.generate(combinations.combine(mode=["eager"]))
    def testCallableParams(self):
        for dtype in _DATA_TYPES:
            var0 = variables.Variable([1.0, 2.0], dtype=dtype)
            var1 = variables.Variable([3.0, 4.0], dtype=dtype)
            grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
            grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)

            learning_rate = lambda: 2.0
            rho = lambda: 0.9
            momentum = lambda: 0.0
            epsilon = 1.0
            opt = rmsprop.RMSprop(learning_rate, rho, momentum, epsilon)

            # Fetch params to validate initial values
            self.assertAllClose([1.0, 2.0], self.evaluate(var0))
            self.assertAllClose([3.0, 4.0], self.evaluate(var1))
            # Step 1: the rms accumulators start near zero and epsilon is 1.0,
            # so we should see a near-normal update: v -= grad * learning_rate
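            # With g = 0.1 and rho = 0.9: rms_1 = (1 - rho) * g**2 = 0.001,
            # and rms_2 = 0.9 * 0.001 + 0.001 = 0.0019, which is where the
            # sqrt(0.001 + 1.0) and sqrt(0.001 * 0.9 + 0.001 + 1.0) factors
            # in the assertions below come from (epsilon = 1.0).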
            opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
            # Check the parameters.
            self.assertAllCloseAccordingToType(
                np.array([
                    1.0 - (0.1 * 2.0 / math.sqrt(0.001 + 1.0)),
                    2.0 - (0.1 * 2.0 / math.sqrt(0.001 + 1.0))
                ]), self.evaluate(var0))
            self.assertAllCloseAccordingToType(
                np.array([
                    3.0 - (0.01 * 2.0 / math.sqrt(0.00001 + 1.0)),
                    4.0 - (0.01 * 2.0 / math.sqrt(0.00001 + 1.0))
                ]), self.evaluate(var1))
            # Step 2: the root mean square accumulators contain the previous update.
            opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
            # Check the parameters.
            self.assertAllCloseAccordingToType(
                np.array([
                    1.0 - (0.1 * 2.0 / math.sqrt(0.001 + 1.0)) -
                    (0.1 * 2.0 / math.sqrt(0.001 * 0.9 + 0.001 + 1.0)),
                    2.0 - (0.1 * 2.0 / math.sqrt(0.001 + 1.0)) -
                    (0.1 * 2.0 / math.sqrt(0.001 * 0.9 + 0.001 + 1.0))
                ]), self.evaluate(var0))
            self.assertAllCloseAccordingToType(
                np.array([
                    3.0 - (0.01 * 2.0 / math.sqrt(0.00001 + 1.0)) -
                    (0.01 * 2.0 / math.sqrt(0.00001 * 0.9 + 1e-5 + 1.0)),
                    4.0 - (0.01 * 2.0 / math.sqrt(0.00001 + 1.0)) -
                    (0.01 * 2.0 / math.sqrt(0.00001 * 0.9 + 1e-5 + 1.0))
                ]), self.evaluate(var1))

    def testConstructRMSpropWithLR(self):
        opt = rmsprop.RMSprop(lr=1.0)
        opt_2 = rmsprop.RMSprop(learning_rate=0.1, lr=1.0)
        opt_3 = rmsprop.RMSprop(learning_rate=0.1)
        self.assertIsInstance(opt.lr, variables.Variable)
        self.assertIsInstance(opt_2.lr, variables.Variable)
        self.assertIsInstance(opt_3.lr, variables.Variable)

        self.evaluate(variables.global_variables_initializer())
        self.assertAllClose(self.evaluate(opt.lr), (1.0))
        self.assertAllClose(self.evaluate(opt_2.lr), (1.0))
        self.assertAllClose(self.evaluate(opt_3.lr), (0.1))

    @combinations.generate(combinations.combine(mode=["eager"]))
    def testSlotsUniqueEager(self):
        v1 = variables.Variable(1.)
        v2 = variables.Variable(1.)

        opt = rmsprop.RMSprop(1., momentum=0., centered=False)
        opt.minimize(lambda: v1 + v2, var_list=[v1, v2])
        # There should be the iteration count, and one unique slot variable
        # for each of v1 and v2.
        self.assertLen(set({id(v) for v in opt.variables()}), 3)
        self.assertEqual(self.evaluate(opt.variables()[0]),
                         self.evaluate(opt.iterations))

        opt = rmsprop.RMSprop(learning_rate=1., momentum=0.2, centered=False)
        opt.minimize(lambda: v1 + v2, var_list=[v1, v2])
        # There should be the iteration count, and two unique slot variables
        # for each of v1 and v2.
        self.assertLen(set({id(v) for v in opt.variables()}), 5)
        self.assertEqual(self.evaluate(opt.variables()[0]),
                         self.evaluate(opt.iterations))

        opt = rmsprop.RMSprop(learning_rate=1., momentum=0.2, centered=True)
        opt.minimize(lambda: v1 + v2, var_list=[v1, v2])
        # There should be the iteration count, and three unique slot variables
        # for each of v1 and v2.
        self.assertLen(set({id(v) for v in opt.variables()}), 7)
        self.assertEqual(self.evaluate(opt.variables()[0]),
                         self.evaluate(opt.iterations))
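Example #4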
class BatchNormalizationV2Test(keras_parameterized.TestCase):

  @keras_parameterized.run_all_keras_modes
  def test_basic_batchnorm_v2(self):
    testing_utils.layer_test(
        normalization_v2.BatchNormalization,
        kwargs={'fused': True},
        input_shape=(3, 3, 3, 3))
    testing_utils.layer_test(
        normalization_v2.BatchNormalization,
        kwargs={'fused': None},
        input_shape=(3, 3, 3))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_v2_fused_attribute(self):
    norm = normalization_v2.BatchNormalization()
    self.assertEqual(norm.fused, None)
    inp = keras.layers.Input(shape=(4, 4, 4))
    norm(inp)
    self.assertEqual(norm.fused, True)

    norm = normalization_v2.BatchNormalization()
    self.assertEqual(norm.fused, None)
    inp = keras.layers.Input(shape=(4, 4))
    norm(inp)
    self.assertEqual(norm.fused, False)

    norm = normalization_v2.BatchNormalization()
    self.assertIsNone(norm.fused)
    inp = keras.layers.Input(shape=(4, 4, 4, 4))
    norm(inp)
    self.assertEqual(norm.fused, False)

    norm = normalization_v2.BatchNormalization(virtual_batch_size=2)
    self.assertEqual(norm.fused, False)
    inp = keras.layers.Input(shape=(4, 4, 4))
    norm(inp)
    self.assertEqual(norm.fused, False)

    norm = normalization_v2.BatchNormalization(fused=False)
    self.assertEqual(norm.fused, False)
    inp = keras.layers.Input(shape=(4, 4, 4))
    norm(inp)
    self.assertEqual(norm.fused, False)

    norm = normalization_v2.BatchNormalization(fused=True, axis=[3])
    self.assertEqual(norm.fused, True)
    inp = keras.layers.Input(shape=(4, 4, 4))
    norm(inp)
    self.assertEqual(norm.fused, True)

    with self.assertRaisesRegex(ValueError, 'fused.*renorm'):
      normalization_v2.BatchNormalization(fused=True, renorm=True)

    with self.assertRaisesRegex(ValueError, 'fused.*when axis is 1 or 3'):
      normalization_v2.BatchNormalization(fused=True, axis=2)

    with self.assertRaisesRegex(ValueError, 'fused.*when axis is 1 or 3'):
      normalization_v2.BatchNormalization(fused=True, axis=[1, 3])

    with self.assertRaisesRegex(ValueError, 'fused.*virtual_batch_size'):
      normalization_v2.BatchNormalization(fused=True, virtual_batch_size=2)

    with self.assertRaisesRegex(ValueError, 'fused.*adjustment'):
      normalization_v2.BatchNormalization(fused=True,
                                          adjustment=lambda _: (1, 0))

    norm = normalization_v2.BatchNormalization(fused=True)
    self.assertEqual(norm.fused, True)
    inp = keras.layers.Input(shape=(4, 4))
    with self.assertRaisesRegex(ValueError, '4D or 5D input tensors'):
      norm(inp)

  def test_updates_in_wrap_function(self):

    def my_func():
      layer = normalization.BatchNormalization()
      x = array_ops.ones((10, 1))
      y = layer(x, training=True)
      # Updates should be tracked in a `wrap_function`.
      self.assertLen(layer.updates, 2)
      return y

    wrapped_fn = wrap_function.wrap_function(my_func, [])
    wrapped_fn()

  @keras_parameterized.run_all_keras_modes
  def test_basic_batchnorm_v2_none_shape_and_virtual_batch_size(self):
    # Test case for GitHub issue 32380.
    norm = normalization_v2.BatchNormalization(virtual_batch_size=8)
    inp = keras.layers.Input(shape=(None, None, 3))
    _ = norm(inp)
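Example #5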
class LayerNormalizationTest(keras_parameterized.TestCase):

  @keras_parameterized.run_all_keras_modes
  def test_basic_layernorm(self):
    testing_utils.layer_test(
        keras.layers.LayerNormalization,
        kwargs={
            'gamma_regularizer': keras.regularizers.l2(0.01),
            'beta_regularizer': keras.regularizers.l2(0.01)
        },
        input_shape=(3, 4, 2))
    testing_utils.layer_test(
        keras.layers.LayerNormalization,
        kwargs={
            'gamma_initializer': 'ones',
            'beta_initializer': 'ones',
        },
        input_shape=(3, 4, 2))
    testing_utils.layer_test(
        keras.layers.LayerNormalization,
        kwargs={'scale': False,
                'center': False},
        input_shape=(3, 3))
    testing_utils.layer_test(
        keras.layers.LayerNormalization,
        kwargs={'axis': (-3, -2, -1)},
        input_shape=(2, 8, 8, 3))

  @keras_parameterized.run_all_keras_modes
  def test_non_fused_layernorm(self):
    testing_utils.layer_test(
        keras.layers.LayerNormalization,
        kwargs={'axis': -2},
        input_shape=(3, 4, 2))
    testing_utils.layer_test(
        keras.layers.LayerNormalization,
        kwargs={'axis': (-3, -2)},
        input_shape=(2, 8, 8, 3))
    testing_utils.layer_test(
        keras.layers.LayerNormalization,
        kwargs={'axis': (-3, -1)},
        input_shape=(2, 8, 8, 3))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_layernorm_weights(self):
    layer = keras.layers.LayerNormalization(scale=False, center=False)
    layer.build((None, 3, 4))
    self.assertEqual(len(layer.trainable_weights), 0)
    self.assertEqual(len(layer.weights), 0)

    layer = keras.layers.LayerNormalization()
    layer.build((None, 3, 4))
    self.assertEqual(len(layer.trainable_weights), 2)
    self.assertEqual(len(layer.weights), 2)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_layernorm_regularization(self):
    layer = keras.layers.LayerNormalization(
        gamma_regularizer='l1', beta_regularizer='l1')
    layer.build((None, 3, 4))
    self.assertEqual(len(layer.losses), 2)
    max_norm = keras.constraints.max_norm
    layer = keras.layers.LayerNormalization(
        gamma_constraint=max_norm, beta_constraint=max_norm)
    layer.build((None, 3, 4))
    self.assertEqual(layer.gamma.constraint, max_norm)
    self.assertEqual(layer.beta.constraint, max_norm)

  @keras_parameterized.run_all_keras_modes
  def test_layernorm_convnet_channel_last(self):
    model = keras.models.Sequential()
    norm = keras.layers.LayerNormalization(input_shape=(4, 4, 3))
    model.add(norm)
    model.compile(
        loss='mse',
        optimizer=gradient_descent.GradientDescentOptimizer(0.01),
        run_eagerly=testing_utils.should_run_eagerly())

    # centered on 5.0, standard deviation 10.0
    x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 4, 4, 3))
    model.fit(x, x, epochs=4, verbose=0)
    out = model.predict(x)
    out -= np.reshape(keras.backend.eval(norm.beta), (1, 1, 1, 3))
    out /= np.reshape(keras.backend.eval(norm.gamma), (1, 1, 1, 3))

    np.testing.assert_allclose(np.mean(out, axis=(0, 1, 2)), 0.0, atol=1e-1)
    np.testing.assert_allclose(np.std(out, axis=(0, 1, 2)), 1.0, atol=1e-1)

  @keras_parameterized.run_all_keras_modes
  def test_layernorm_correctness(self):
    _run_layernorm_correctness_test(
        normalization.LayerNormalization, dtype='float32')

  @keras_parameterized.run_all_keras_modes
  def test_layernorm_mixed_precision(self):
    _run_layernorm_correctness_test(
        normalization.LayerNormalization, dtype='float16')

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testIncorrectAxisType(self):
    with self.assertRaisesRegex(TypeError,
                                r'Expected an int or a list/tuple of ints'):
      _ = normalization.LayerNormalization(axis={'axis': -1})

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testInvalidAxis(self):
    with self.assertRaisesRegex(ValueError, r'Invalid axis: 3'):
      layer_norm = normalization.LayerNormalization(axis=3)
      layer_norm.build(input_shape=(2, 2, 2))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testDuplicateAxis(self):
    with self.assertRaisesRegex(ValueError, r'Duplicate axis:'):
      layer_norm = normalization.LayerNormalization(axis=[-1, -1])
      layer_norm.build(input_shape=(2, 2, 2))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testFusedAttr(self):
    layer_norm = normalization.LayerNormalization(axis=[-2, -1])
    layer_norm.build(input_shape=(2, 2, 2))
    self.assertEqual(layer_norm._fused, True)
Example #6
class AdadeltaOptimizerTest(test.TestCase, parameterized.TestCase):
    def doTestBasic(self, use_resource=False, use_callable_params=False):
        num_updates = 4  # number of ADADELTA steps to perform
        for dtype in _DATA_TYPES:
            for grad in [0.2, 0.1, 0.01]:
                for lr in [1.0, 0.5, 0.1]:
                    var0_init = [1.0, 2.0]
                    var1_init = [3.0, 4.0]
                    # Resource variables are the default in TF2, so the
                    # use_resource flag no longer changes what gets created
                    # and both former branches were identical.
                    var0 = variables.Variable(var0_init, dtype=dtype)
                    var1 = variables.Variable(var1_init, dtype=dtype)

                    grads = constant_op.constant([grad, grad], dtype=dtype)

                    accum = 0.0
                    accum_update = 0.0

                    # ADADELTA gradient optimizer
                    rho = 0.95
                    epsilon = 1e-8
                    if use_callable_params:
                        adadelta_opt = adadelta.Adadelta(
                            learning_rate=lambda: lr,  # pylint: disable=cell-var-from-loop
                            rho=lambda: rho,  # pylint: disable=cell-var-from-loop
                            epsilon=epsilon)  # pylint: disable=cell-var-from-loop
                    else:
                        adadelta_opt = adadelta.Adadelta(learning_rate=lr,
                                                         rho=rho,
                                                         epsilon=epsilon)
                    if not context.executing_eagerly():
                        adadelta_update = adadelta_opt.apply_gradients(
                            zip([grads, grads], [var0, var1]))
                        self.evaluate(variables.global_variables_initializer())

                        # Assign slots
                        slot = [None] * 2
                        slot_update = [None] * 2
                        slot[0] = adadelta_opt.get_slot(var0, "accum_grad")
                        self.assertEqual(slot[0].shape, var0.shape)

                        slot_update[0] = adadelta_opt.get_slot(
                            var0, "accum_var")
                        self.assertEqual(slot_update[0].shape, var0.shape)

                        slot[1] = adadelta_opt.get_slot(var1, "accum_grad")
                        self.assertEqual(slot[1].shape, var1.shape)

                        slot_update[1] = adadelta_opt.get_slot(
                            var1, "accum_var")
                        self.assertEqual(slot_update[1].shape, var1.shape)

                    # Fetch params to validate initial values
                    self.assertAllClose(var0_init, self.evaluate(var0))
                    self.assertAllClose(var1_init, self.evaluate(var1))

                    update = [None] * num_updates
                    tot_update = 0
                    for step in range(num_updates):
                        # Run adadelta update for comparison
                        if not context.executing_eagerly():
                            self.evaluate(adadelta_update)
                        else:
                            adadelta_opt.apply_gradients(
                                zip([grads, grads], [var0, var1]))

                        # Advance the numpy reference implementation one step.
                        accum = accum * rho + (grad**2) * (1 - rho)
                        update[step] = (np.sqrt(accum_update + epsilon) *
                                        (1. / np.sqrt(accum + epsilon)) * grad)
                        accum_update = (accum_update * rho +
                                        (update[step]**2) * (1.0 - rho))
                        tot_update += update[step] * lr

                        if not context.executing_eagerly():
                            # Check that the accumulators have been updated
                            # TODO(lxuechen): This is hard to test in eager mode
                            for slot_idx in range(2):
                                self.assertAllCloseAccordingToType(
                                    np.array([accum, accum],
                                             dtype=dtype.as_numpy_dtype(0)),
                                    self.evaluate(slot[slot_idx]),
                                    rtol=1e-5)

                                self.assertAllCloseAccordingToType(
                                    np.array([accum_update, accum_update],
                                             dtype=dtype.as_numpy_dtype(0)),
                                    self.evaluate(slot_update[slot_idx]),
                                    rtol=1e-5)

                            # Check that the parameters have been updated
                            self.assertAllCloseAccordingToType(
                                np.array([
                                    var0_init[0] - tot_update,
                                    var0_init[1] - tot_update
                                ],
                                         dtype=dtype.as_numpy_dtype(0)),
                                self.evaluate(var0),
                                rtol=1e-5)

                            self.assertAllCloseAccordingToType(
                                np.array([
                                    var1_init[0] - tot_update,
                                    var1_init[1] - tot_update
                                ],
                                         dtype=dtype.as_numpy_dtype(0)),
                                self.evaluate(var1),
                                rtol=1e-5)

    @combinations.generate(combinations.combine(mode=["graph", "eager"]))
    def testResourceBasic(self):
        self.doTestBasic(use_resource=True)

    @combinations.generate(combinations.combine(mode=["eager"]))
    def testBasicCallableParams(self):
        self.doTestBasic(use_resource=True, use_callable_params=True)

    def testMinimizeSparseResourceVariable(self):
        # TODO(tanzheny, omalleyt): Fix test in eager mode.
        with ops.Graph().as_default():
            for dtype in _DATA_TYPES:
                var0 = variables.Variable([[1.0, 2.0]], dtype=dtype)
                x = constant_op.constant([[4.0], [5.0]], dtype=dtype)

                def loss():
                    pred = math_ops.matmul(
                        embedding_ops.embedding_lookup([var0], [0]), x)  # pylint: disable=cell-var-from-loop
                    return pred * pred

                sgd_op = adadelta.Adadelta(1.0, 1.0,
                                           1.0).minimize(loss, var_list=[var0])
                self.evaluate(variables.global_variables_initializer())
                # Fetch params to validate initial values
                self.assertAllCloseAccordingToType([[1.0, 2.0]],
                                                   self.evaluate(var0))
                # Run 1 step of Adadelta
                self.evaluate(sgd_op)
                # Validate updated params
                self.assertAllCloseAccordingToType([[-111, -138]],
                                                   self.evaluate(var0))

    def testConstructAdadeltaWithLR(self):
        opt = adadelta.Adadelta(lr=1.0, rho=0.9, epsilon=1.)
        opt_2 = adadelta.Adadelta(learning_rate=0.1,
                                  rho=0.9,
                                  epsilon=1.,
                                  lr=1.0)
        opt_3 = adadelta.Adadelta(learning_rate=0.1, rho=0.9, epsilon=1.)
        self.assertIsInstance(opt.lr, variables.Variable)
        self.assertIsInstance(opt_2.lr, variables.Variable)
        self.assertIsInstance(opt_3.lr, variables.Variable)

        self.evaluate(variables.global_variables_initializer())
        self.assertAllClose(self.evaluate(opt.lr), (1.0))
        self.assertAllClose(self.evaluate(opt_2.lr), (1.0))
        self.assertAllClose(self.evaluate(opt_3.lr), (0.1))

    def testConstructAdadeltaWithEpsilonValues(self):
        opt = adadelta.Adadelta(epsilon=None)
        self.assertEqual(opt.epsilon, 1e-7)

        opt = adadelta.Adadelta(epsilon=1e-8)
        self.assertEqual(opt.epsilon, 1e-8)
Example #7
class CheckpointingTests(keras_parameterized.TestCase):
    @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
    def testNamingWithOptimizer(self):
        input_value = constant_op.constant([[3.]])
        model = MyModel()
        # A nuisance Model using the same optimizer. Its slot variables should not
        # go in the checkpoint, since it is never depended on.
        other_model = MyModel()
        optimizer = adam.Adam(0.001)
        step = training_util.get_or_create_global_step()
        root_trackable = trackable_utils.Checkpoint(optimizer=optimizer,
                                                    model=model,
                                                    step=step)

        with backprop.GradientTape() as tape:
            loss = model(input_value)
        variables = model.trainable_variables
        gradients = tape.gradient(loss, variables)
        train_op = control_flow_ops.group(
            optimizer.apply_gradients(zip(gradients, variables)),
            step.assign_add(1))

        with backprop.GradientTape() as tape:
            loss = other_model(input_value)
        variables = other_model.trainable_variables
        gradients = tape.gradient(loss, variables)
        optimizer.apply_gradients(zip(gradients, variables))

        self.evaluate(trackable_utils.gather_initializers(root_trackable))
        self.evaluate(train_op)
        named_variables, serialized_graph, _ = graph_view.ObjectGraphView(
            root_trackable).serialize_object_graph()
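        # Slot-variable checkpoint keys take the form
        # <variable path>/.OPTIMIZER_SLOT/<optimizer path>/<slot name>,
        # as listed below.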
        expected_slot_keys = (
            "model/_second/kernel/.OPTIMIZER_SLOT/optimizer/m",
            "model/_second/kernel/.OPTIMIZER_SLOT/optimizer/v",
            "model/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/m",
            "model/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/v",
            "model/_named_dense/bias/.OPTIMIZER_SLOT/optimizer/m",
            "model/_named_dense/bias/.OPTIMIZER_SLOT/optimizer/v",
        )
        expected_checkpoint_names = (
            # Created in the root node, so no prefix.
            "step",
            "model/_second/kernel",
            "model/_named_dense/kernel",
            "model/_named_dense/bias",
            # non-Layer dependency of the model
            "model/_non_layer/a_variable",
            "optimizer/learning_rate",
            "optimizer/beta_1",
            "optimizer/beta_2",
            "optimizer/iter",
            "optimizer/decay",
        ) + expected_slot_keys
        suffix = "/.ATTRIBUTES/VARIABLE_VALUE"
        expected_checkpoint_names = [
            name + suffix for name in expected_checkpoint_names
        ]
        named_variables = {v.name: v for v in named_variables}
        six.assertCountEqual(self, expected_checkpoint_names,
                             named_variables.keys())
        # Check that we've mapped to the right variable objects (not exhaustive)
        self.assertEqual("global_step",
                         named_variables["step" + suffix].full_name)
        self.assertEqual(
            "my_model/dense_1/kernel",
            named_variables["model/_second/kernel" + suffix].full_name)
        self.assertEqual(
            "my_model/dense/kernel",
            named_variables["model/_named_dense/kernel" + suffix].full_name)
        self.assertEqual(
            "Adam/beta_1",
            named_variables["optimizer/beta_1" + suffix].full_name)
        self.assertEqual(
            "Adam/beta_2",
            named_variables["optimizer/beta_2" + suffix].full_name)
        # Spot check the generated protocol buffers.
        self.assertEqual("optimizer",
                         serialized_graph.nodes[0].children[1].local_name)
        optimizer_node = serialized_graph.nodes[
            serialized_graph.nodes[0].children[1].node_id]
        children = [node.local_name for node in optimizer_node.children]
        six.assertCountEqual(
            self,
            # hyper variable dependencies
            ["beta_1", "beta_2", "iter", "decay", "learning_rate"],
            children)
        serialized_slot_keys = []
        for slot in optimizer_node.slot_variables:
            for attribute in (serialized_graph.nodes[
                    slot.slot_variable_node_id].attributes):
                serialized_slot_keys.append(attribute.checkpoint_key)
        six.assertCountEqual(self,
                             [key + suffix for key in expected_slot_keys],
                             serialized_slot_keys)

    @combinations.generate(combinations.combine(mode=["graph", "eager"]))
    def testSaveRestore(self):
        with self.test_session():
            model = MyModel()
            optimizer = adam.Adam(0.001)
            root_trackable = trackable_utils.Checkpoint(optimizer=optimizer,
                                                        model=model)
            input_value = constant_op.constant([[3.]])
            with backprop.GradientTape() as tape:
                loss = model(input_value)
            variables = model.trainable_variables
            gradients = tape.gradient(loss, variables)
            train_op = optimizer.apply_gradients(zip(gradients, variables))
            self.assertFalse(root_trackable.save_counter.trainable)
            self.evaluate(trackable_utils.gather_initializers(root_trackable))
            self.evaluate(train_op)
            prefix = os.path.join(self.get_temp_dir(), "ckpt")
            self.evaluate(
                state_ops.assign(model._named_dense.variables[1], [42.]))
            m_bias_slot = optimizer.get_slot(model._named_dense.variables[1],
                                             "m")
            self.evaluate(state_ops.assign(m_bias_slot, [1.5]))
            save_path = root_trackable.save(file_prefix=prefix)
            self.evaluate(
                state_ops.assign(model._named_dense.variables[1], [43.]))
            self.evaluate(state_ops.assign(root_trackable.save_counter, 3))
            optimizer_variables = self.evaluate(
                sorted(optimizer.variables(), key=lambda v: v.name))
            self.evaluate(state_ops.assign(m_bias_slot, [-2.]))
            # Immediate restoration
            status = root_trackable.restore(
                save_path=save_path).assert_consumed()
            status.run_restore_ops()
            self.assertAllEqual([42.],
                                self.evaluate(model._named_dense.variables[1]))
            self.assertAllEqual(1, self.evaluate(root_trackable.save_counter))
            self.assertAllEqual([1.5], self.evaluate(m_bias_slot))
            if not context.executing_eagerly():
                return  # Restore-on-create is only supported when executing eagerly
            on_create_model = MyModel()
            on_create_optimizer = adam.Adam(0.001)
            on_create_root = trackable_utils.Checkpoint(
                optimizer=on_create_optimizer, model=on_create_model)
            # Deferred restoration
            status = on_create_root.restore(save_path=save_path)
            status.assert_nontrivial_match()
            status.assert_existing_objects_matched()
            with self.assertRaises(AssertionError):
                status.assert_consumed()
            on_create_model(constant_op.constant([[3.]]))  # create variables
            self.assertAllEqual(1, self.evaluate(on_create_root.save_counter))
            self.assertAllEqual([42.],
                                self.evaluate(
                                    on_create_model._named_dense.variables[1]))
            on_create_m_bias_slot = on_create_optimizer.get_slot(
                on_create_model._named_dense.variables[1], "m")
            status.assert_existing_objects_matched()
            if not context.executing_eagerly():
                with self.assertRaises(AssertionError):
                    status.assert_consumed()
            # Optimizer slot variables are created when the original variable is
            # restored.
            self.assertAllEqual([1.5], self.evaluate(on_create_m_bias_slot))
            dummy_var = variables_lib.Variable([1.])
            on_create_optimizer.minimize(loss=dummy_var.read_value,
                                         var_list=[dummy_var])
            status.assert_existing_objects_matched()
            status.assert_consumed()
            self.assertAllEqual(
                optimizer_variables,
                # Creation order is different, so .variables() needs to be re-sorted.
                self.evaluate(
                    sorted(optimizer.variables(), key=lambda v: v.name)))

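The same save/restore round trip against the public TF2 API rather than the test helpers (a minimal sketch; the path, layer size, and clobbered value are illustrative):

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.Adam(0.001)
ckpt = tf.train.Checkpoint(optimizer=optimizer, model=model)
with tf.GradientTape() as tape:
    loss = model(tf.constant([[3.]]))
grads = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))
save_path = ckpt.save("/tmp/ckpt_demo/ckpt")
bias = model.layers[0].bias
bias.assign([42.])                         # clobber a value...
ckpt.restore(save_path).assert_consumed()  # ...and bring it back
print(bias.numpy())                        # the pre-clobber value again
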
    # TODO(allenl): Debug garbage created by this test in python3.
    def testDeferredRestorationUsageEager(self):
        """An idiomatic eager execution example."""
        num_training_steps = 10
        checkpoint_directory = self.get_temp_dir()
        checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
        for training_continuation in range(3):
            model = MyModel()
            optimizer = adam.Adam(0.001)
            root = trackable_utils.Checkpoint(optimizer=optimizer, model=model)
            root.restore(
                checkpoint_management.latest_checkpoint(checkpoint_directory))
            for _ in range(num_training_steps):
                # TODO(allenl): Use a Dataset and serialize/checkpoint it.
                input_value = constant_op.constant([[3.]])
                with backprop.GradientTape() as tape:
                    loss = model(input_value)
                variables = model.trainable_variables
                gradients = tape.gradient(loss, variables)
                optimizer.apply_gradients(zip(gradients, variables))
            root.save(file_prefix=checkpoint_prefix)
            self.assertEqual((training_continuation + 1) * num_training_steps,
                             root.optimizer.iterations.numpy())

    def testUsageGraph(self):
        """Expected usage when graph building."""
        with context.graph_mode():
            num_training_steps = 10
            checkpoint_directory = self.get_temp_dir()
            checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
            for training_continuation in range(3):
                with ops.Graph().as_default():
                    model = MyModel()
                    optimizer = adam.Adam(0.001)
                    root = trackable_utils.CheckpointV1(optimizer=optimizer,
                                                        model=model)
                    input_value = constant_op.constant([[3.]])
                    with backprop.GradientTape() as tape:
                        loss = model(input_value)
                    variables = model.trainable_variables
                    gradients = tape.gradient(loss, variables)
                    train_op = optimizer.apply_gradients(
                        zip(gradients, variables))

                    checkpoint_path = checkpoint_management.latest_checkpoint(
                        checkpoint_directory)
                    with self.session(
                            graph=ops.get_default_graph()) as session:
                        status = root.restore(save_path=checkpoint_path)
                        status.initialize_or_restore(session=session)
                        if checkpoint_path is None:
                            self.assertEqual(0, training_continuation)
                            with self.assertRaises(AssertionError):
                                status.assert_consumed()
                            with self.assertRaises(AssertionError):
                                status.assert_existing_objects_matched()
                        else:
                            status.assert_consumed()
                            status.assert_existing_objects_matched()
                        for _ in range(num_training_steps):
                            session.run(train_op)
                        root.save(file_prefix=checkpoint_prefix,
                                  session=session)
                        self.assertEqual(
                            (training_continuation + 1) * num_training_steps,
                            session.run(root.optimizer.iterations))
                        self.assertEqual(training_continuation + 1,
                                         session.run(root.save_counter))

    @combinations.generate(combinations.combine(mode=["graph", "eager"]))
    def testAgnosticUsage(self):
        """Graph/eager agnostic usage."""
        # Does create garbage when executing eagerly due to ops.Graph() creation.
        with self.test_session():
            num_training_steps = 10
            checkpoint_directory = self.get_temp_dir()
            optimizer = adam.Adam(0.001)

            def _train_fn(model, input_value):
                with backprop.GradientTape() as tape:
                    loss = model(input_value)
                variables = model.trainable_variables
                gradients = tape.gradient(loss, variables)
                return optimizer.apply_gradients(zip(gradients, variables))

            for training_continuation in range(3):
                with testing_utils.device(should_use_gpu=True):
                    model = MyModel()
                    root = trackable_utils.Checkpoint(optimizer=optimizer,
                                                      model=model)
                    manager = checkpoint_management.CheckpointManager(
                        root, checkpoint_directory, max_to_keep=1)
                    status = root.restore(save_path=manager.latest_checkpoint)
                    input_value = constant_op.constant([[3.]])
                    train_fn = functools.partial(_train_fn, model, input_value)
                    if not context.executing_eagerly():
                        train_fn = functools.partial(self.evaluate, train_fn())
                    status.initialize_or_restore()
                    for _ in range(num_training_steps):
                        train_fn()
                    manager.save()
                    self.assertEqual(
                        (training_continuation + 1) * num_training_steps,
                        self.evaluate(root.optimizer.iterations))
                    self.assertEqual(training_continuation + 1,
                                     self.evaluate(root.save_counter))

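The graph/eager-agnostic resume loop above reduces to a few lines with the public tf.train.CheckpointManager (a sketch; the directory and step count are illustrative):

import tensorflow as tf

step = tf.Variable(0, dtype=tf.int64)
ckpt = tf.train.Checkpoint(step=step)
manager = tf.train.CheckpointManager(ckpt, "/tmp/mgr_demo", max_to_keep=1)
# latest_checkpoint is None on a cold start and restore(None) is a no-op,
# so the same code path handles both starting fresh and resuming.
ckpt.restore(manager.latest_checkpoint)
for _ in range(10):
    step.assign_add(1)
manager.save()
print(manager.latest_checkpoint)
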
    @combinations.generate(combinations.combine(mode=["eager"]))
    def testPartialRestoreWarningObject(self):
        optimizer = adam.Adam(0.0)
        original_root = trackable_utils.Checkpoint(
            v1=variables_lib.Variable(2.),
            v2=variables_lib.Variable(3.),
            optimizer=optimizer)
        # Create a slot variable to save
        optimizer.minimize(original_root.v1.read_value, [original_root.v1])
        prefix = os.path.join(self.get_temp_dir(), "ckpt")
        save_path = original_root.save(prefix)
        partial_root = trackable_utils.Checkpoint(
            v1=variables_lib.Variable(0.))
        weak_partial_root = weakref.ref(partial_root)
        weak_v1 = weakref.ref(partial_root.v1)
        partial_root.restore(save_path)
        self.assertEqual(2., partial_root.v1.numpy())
        with test.mock.patch.object(logging, "warning") as mock_log:
            del partial_root
            self.assertIsNone(weak_partial_root())
            self.assertIsNone(weak_v1())
            messages = str(mock_log.call_args_list)
        self.assertIn("(root).v2'", messages)
        self.assertIn("(root).optimizer's state 'm' for (root).v1", messages)
        self.assertNotIn("(root).v1'", messages)
        self.assertIn("expect_partial()", messages)

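The warning tested above fires when a restored checkpoint still holds unmatched values at object deletion; expect_partial() is the documented way to opt out. A minimal sketch (the variable names and path are illustrative):

import tensorflow as tf

full = tf.train.Checkpoint(v1=tf.Variable(2.), v2=tf.Variable(3.))
path = full.save("/tmp/partial_demo/ckpt")
# This object graph deliberately omits v2; without expect_partial() a
# warning about the unused value would be logged when the status dies.
partial = tf.train.Checkpoint(v1=tf.Variable(0.))
partial.restore(path).expect_partial()
assert partial.v1.numpy() == 2.
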
    # pylint: disable=cell-var-from-loop
    @combinations.generate(combinations.combine(mode=["graph", "eager"]))
    def testWithDefun(self):
        with self.test_session():
            num_training_steps = 2
            checkpoint_directory = self.get_temp_dir()
            checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
            for training_continuation in range(3):
                with testing_utils.device(should_use_gpu=True):
                    model = MyModel()
                    # Don't actually train so we can test variable values
                    optimizer = adam.Adam(0.)
                    root = trackable_utils.Checkpoint(optimizer=optimizer,
                                                      model=model)
                    checkpoint_path = checkpoint_management.latest_checkpoint(
                        checkpoint_directory)
                    status = root.restore(save_path=checkpoint_path)

                    def train_fn():
                        @def_function.function
                        def _call_model(x):
                            return model(x)

                        with backprop.GradientTape() as tape:
                            loss = _call_model(constant_op.constant([[3.]]))
                        gradients = tape.gradient(loss, model.variables)
                        return optimizer.apply_gradients(
                            zip(gradients, model.variables))

                    if not context.executing_eagerly():
                        train_fn = functools.partial(self.evaluate, train_fn())
                    status.initialize_or_restore()
                    for _ in range(num_training_steps):
                        train_fn()
                    if training_continuation > 0:
                        status.assert_consumed()
                        self.assertAllClose([[42.]],
                                            self.evaluate(model.variables[0]))
                    else:
                        self.evaluate(model.variables[0].assign([[42.]]))
                    root.save(file_prefix=checkpoint_prefix)
                    self.assertEqual(
                        (training_continuation + 1) * num_training_steps,
                        self.evaluate(optimizer.iterations))
                    self.assertEqual(training_continuation + 1,
                                     self.evaluate(root.save_counter))

    # pylint: enable=cell-var-from-loop

    @combinations.generate(combinations.combine(mode=["eager"]))
    def testAnonymousVarsInInit(self):
        class Model(training.Model):
            def __init__(self):
                super(Model, self).__init__()
                self.w = variables_lib.Variable(0.0)
                self.b = variables_lib.Variable(0.0)
                self.vars = [self.w, self.b]

            def call(self, x):
                return x * self.w + self.b

        model = Model()
        optimizer = adam.Adam(learning_rate=0.05)
        checkpoint_directory = self.get_temp_dir()
        checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
        checkpoint = trackable_utils.Checkpoint(model=model,
                                                optimizer=optimizer)
        for _ in range(2):
            checkpoint.save(checkpoint_prefix)
            with backprop.GradientTape() as tape:
                loss = (constant_op.constant(1.) -
                        model(constant_op.constant(1.)))**2
            grad = tape.gradient(loss, model.vars)
            optimizer.apply_gradients([(g, v)
                                       for g, v in zip(grad, model.vars)])

    @combinations.generate(combinations.combine(mode=["graph", "eager"]))
    def testDeferredSlotRestoration(self):
        with self.test_session():
            checkpoint_directory = self.get_temp_dir()

            root = trackable_utils.Checkpoint()
            root.var = trackable_utils.add_variable(root,
                                                    name="var",
                                                    initializer=0.)
            optimizer = adam.Adam(0.1)
            variables = [root.var]
            gradients = [1.]
            train_op = optimizer.apply_gradients(zip(gradients, variables))
            # Note that `optimizer` has not been added as a dependency of
            # `root`. Create a one-off grouping so that slot variables for `root.var`
            # get initialized too.
            self.evaluate(
                trackable_utils.gather_initializers(
                    trackable_utils.Checkpoint(root=root,
                                               optimizer=optimizer)))
            self.evaluate(train_op)
            self.evaluate(state_ops.assign(root.var, 12.))
            no_slots_path = root.save(
                os.path.join(checkpoint_directory, "no_slots"))
            root.optimizer = optimizer
            self.evaluate(state_ops.assign(root.var, 13.))
            self.evaluate(
                state_ops.assign(
                    optimizer.get_slot(slot_name="m", var=root.var), 14.))
            slots_path = root.save(
                os.path.join(checkpoint_directory, "with_slots"))
            new_root = trackable_utils.Checkpoint()
            # Load the slot-containing checkpoint (deferred), then immediately
            # overwrite the non-slot variable (also deferred).
            slot_status = new_root.restore(slots_path)
            no_slot_status = new_root.restore(no_slots_path)
            with self.assertRaises(AssertionError):
                no_slot_status.assert_consumed()
            new_root.var = trackable_utils.add_variable(new_root,
                                                        name="var",
                                                        shape=[])
            no_slot_status.assert_consumed()
            no_slot_status.run_restore_ops()
            self.assertEqual(12., self.evaluate(new_root.var))
            new_root.optimizer = adam.Adam(0.1)
            slot_status.assert_existing_objects_matched()
            if not context.executing_eagerly():
                with self.assertRaisesRegex(AssertionError,
                                            "Unresolved object"):
                    slot_status.assert_consumed()
            self.assertEqual(12., self.evaluate(new_root.var))
            if context.executing_eagerly():
                # Slot variables are only created with restoring initializers when
                # executing eagerly.
                self.assertEqual(
                    14.,
                    self.evaluate(
                        new_root.optimizer.get_slot(slot_name="m",
                                                    var=new_root.var)))
            else:
                # Slot variables are not created eagerly when graph building.
                with self.assertRaises(KeyError):
                    new_root.optimizer.get_slot(slot_name="m",
                                                var=new_root.var)
            variables = [new_root.var]
            gradients = [1.]
            train_op = new_root.optimizer.apply_gradients(
                zip(gradients, variables))
            # The slot variable now exists; restore() didn't create it, but we should
            # now have a restore op for it.
            slot_status.run_restore_ops()
            if not context.executing_eagerly():
                # The train op hasn't run when graph building, so the slot variable has
                # its restored value. It has run in eager, so the value will
                # be different.
                self.assertEqual(
                    14.,
                    self.evaluate(
                        new_root.optimizer.get_slot(slot_name="m",
                                                    var=new_root.var)))
            self.evaluate(train_op)
            slot_status.assert_consumed()

    def testManySavesGraph(self):
        """Saves after the first should not modify the graph."""
        with context.graph_mode():
            graph = ops.Graph()
            with graph.as_default(), self.session(graph):
                checkpoint_directory = self.get_temp_dir()
                checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
                obj = trackable_utils.Checkpoint()
                obj.var = variables_lib.Variable(0., name="v")
                obj.opt = adam.Adam(0.1)
                variables = [obj.var]
                gradients = [1.]
                obj.opt.apply_gradients(zip(gradients, variables))
                self.evaluate(trackable_utils.gather_initializers(obj))
                obj.save(checkpoint_prefix)
                graph.finalize()
                obj.save(checkpoint_prefix)

    def testManyRestoresGraph(self):
        """Restores after the first should not modify the graph."""
        with context.graph_mode():
            graph = ops.Graph()
            with graph.as_default(), self.session(graph):
                checkpoint_directory = self.get_temp_dir()
                checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
                obj = trackable_utils.Checkpoint()
                obj.var = variables_lib.Variable(0., name="v")
                obj.opt = adam.Adam(0.1)
                variables = [obj.var]
                gradients = [1.]
                obj.opt.apply_gradients(zip(gradients, variables))
                self.evaluate(trackable_utils.gather_initializers(obj))
                save_path = obj.save(checkpoint_prefix)
                obj.restore(save_path)
                graph.finalize()
                obj.restore(save_path)

    @combinations.generate(combinations.combine(mode=["graph", "eager"]))
    def test_sequential(self):
        with self.test_session():
            model = sequential.Sequential()
            checkpoint = trackable_utils.Checkpoint(model=model)
            model.add(core.Dense(4))
            second_dense = core.Dense(5)
            model.add(second_dense)
            model(constant_op.constant([[1.]]))
            checkpoint.restore(None).initialize_or_restore()
            self.evaluate(
                second_dense.bias.assign(
                    constant_op.constant([1., 2., 3., 4., 5.])))
            checkpoint_directory = self.get_temp_dir()
            checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
            save_path = checkpoint.save(checkpoint_prefix)
            self.evaluate(
                second_dense.bias.assign(
                    constant_op.constant([5., 6., 7., 8., 9.])))
            checkpoint.restore(save_path).assert_consumed().run_restore_ops()
            self.assertAllEqual([1., 2., 3., 4., 5.],
                                self.evaluate(second_dense.bias))

            deferred_sequential = sequential.Sequential()
            deferred_sequential_checkpoint = trackable_utils.Checkpoint(
                model=deferred_sequential)
            status = deferred_sequential_checkpoint.restore(save_path)
            deferred_sequential.add(core.Dense(4))
            deferred_second_dense = core.Dense(5)
            deferred_sequential.add(deferred_second_dense)
            deferred_sequential(constant_op.constant([[1.]]))
            status.run_restore_ops()
            self.assertAllEqual([1., 2., 3., 4., 5.],
                                self.evaluate(deferred_second_dense.bias))

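Deferred restoration, as exercised by the second half of test_sequential, in public-API form: restored values are queued until matching variables exist and are assigned as soon as the model is built (a sketch with illustrative shapes and paths):

import tensorflow as tf

saved = tf.keras.Sequential([tf.keras.layers.Dense(5)])
saved(tf.constant([[1.]]))
saved.layers[0].bias.assign([1., 2., 3., 4., 5.])
path = tf.train.Checkpoint(model=saved).save("/tmp/seq_demo/ckpt")

fresh = tf.keras.Sequential([tf.keras.layers.Dense(5)])
status = tf.train.Checkpoint(model=fresh).restore(path)  # nothing to assign yet
fresh(tf.constant([[1.]]))  # variables are created here; the restore fires now
print(fresh.layers[0].bias.numpy())  # [1. 2. 3. 4. 5.]
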
    @combinations.generate(combinations.combine(mode=["graph", "eager"]))
    def test_initialize_if_not_restoring(self):
        with self.test_session():
            checkpoint_directory = self.get_temp_dir()
            checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
            optimizer_only_prefix = os.path.join(checkpoint_directory, "opt")
            with testing_utils.device(should_use_gpu=True):
                model = MyModel()
                optimizer = adam.Adam(0.001)
                root = trackable_utils.Checkpoint(
                    model=model
                )  # Do not save the optimizer with the checkpoint.
                optimizer_checkpoint = trackable_utils.Checkpoint(
                    optimizer=optimizer)

                checkpoint_path = checkpoint_management.latest_checkpoint(
                    checkpoint_directory)
                status = root.restore(save_path=checkpoint_path)
                input_value = constant_op.constant([[3.]])

                def train_fn():
                    with backprop.GradientTape() as tape:
                        loss = model(input_value)
                    variables = model.trainable_variables
                    gradients = tape.gradient(loss, variables)
                    return optimizer.apply_gradients(zip(gradients, variables))

                if not context.executing_eagerly():
                    train_fn = functools.partial(self.evaluate, train_fn())
                status.initialize_or_restore()
                # TODO(tanzheny): Add hyper variables to .variables(), and set them with
                # set_weights etc.
                variables_not_in_the_variables_property = [
                    obj for obj in optimizer._hyper.values()
                    if isinstance(obj, variables_lib.Variable)
                ]
                self.evaluate([
                    v.initializer for v in optimizer.variables() +
                    variables_not_in_the_variables_property
                ])
                train_fn()
                model_save_path = root.save(file_prefix=checkpoint_prefix)
                self.evaluate(optimizer.beta_1.assign(42.))
                optimizer_save_path = optimizer_checkpoint.save(
                    optimizer_only_prefix)
            del train_fn

            # Restore into a graph with the optimizer
            with testing_utils.device(should_use_gpu=True):
                model = MyModel()
                optimizer = adam.Adam(0.001)
                root = trackable_utils.Checkpoint(optimizer=optimizer,
                                                  model=model)
                status = root.restore(save_path=model_save_path)
                input_value = constant_op.constant([[3.]])

                def train_fn1():
                    with backprop.GradientTape() as tape:
                        loss = model(input_value)
                    variables = model.trainable_variables
                    gradients = tape.gradient(loss, variables)
                    return optimizer.apply_gradients(zip(gradients, variables))

                if not context.executing_eagerly():
                    train_fn1 = functools.partial(self.evaluate, train_fn1())
                status.initialize_or_restore()
                train_fn1()
                with self.assertRaises(AssertionError):
                    status.assert_existing_objects_matched()
                with self.assertRaises(AssertionError):
                    status.assert_consumed()
            del train_fn1

            # Make sure initialization doesn't clobber later restores
            with testing_utils.device(should_use_gpu=True):
                model = MyModel()
                optimizer = adam.Adam(0.001, beta_1=1.0)
                root = trackable_utils.Checkpoint(optimizer=optimizer,
                                                  model=model)
                opt_root = trackable_utils.Checkpoint(optimizer=optimizer)
                status = root.restore(save_path=model_save_path)
                init_only_optimizer_status = opt_root.restore(save_path=None)
                optimizer_status = opt_root.restore(
                    save_path=optimizer_save_path)
                input_value = constant_op.constant([[3.]])

                def train_fn2():
                    with backprop.GradientTape() as tape:
                        loss = model(input_value)
                    variables = model.trainable_variables
                    gradients = tape.gradient(loss, variables)
                    return optimizer.apply_gradients(zip(gradients, variables))

                if not context.executing_eagerly():
                    train_fn2 = functools.partial(self.evaluate, train_fn2())
                optimizer_status.run_restore_ops()
                status.initialize_or_restore()
                init_only_optimizer_status.initialize_or_restore()
                train_fn2()
                self.assertEqual(42., self.evaluate(optimizer.beta_1))
Example #8
class DenseTest(test.TestCase, parameterized.TestCase):
    @combinations.generate(combinations.combine(mode=['graph', 'eager']))
    def testDenseProperties(self):
        dense = core_layers.Dense(2, activation=nn_ops.relu, name='my_dense')
        self.assertEqual(dense.units, 2)
        self.assertEqual(dense.activation, nn_ops.relu)
        self.assertEqual(dense.kernel_regularizer, None)
        self.assertEqual(dense.bias_regularizer, None)
        self.assertEqual(dense.activity_regularizer, None)
        self.assertEqual(dense.use_bias, True)

        # Test auto-naming
        dense = core_layers.Dense(2, activation=nn_ops.relu)
        dense.apply(random_ops.random_uniform((5, 2)))
        self.assertEqual(dense.name, 'dense_1')
        dense = core_layers.Dense(2, activation=nn_ops.relu)
        dense.apply(random_ops.random_uniform((5, 2)))
        self.assertEqual(dense.name, 'dense_2')

    @test_util.run_deprecated_v1
    def testVariableInput(self):
        with self.cached_session():
            v = variable_scope.get_variable(
                'X', initializer=init_ops.zeros_initializer(), shape=(1, 1))
            x = core_layers.Dense(1)(v)
            self.evaluate(variables.global_variables_initializer())
            self.assertAllEqual(x, [[0.0]])

    @combinations.generate(combinations.combine(mode=['graph', 'eager']))
    def testCall(self):
        dense = core_layers.Dense(2, activation=nn_ops.relu, name='my_dense')
        inputs = random_ops.random_uniform((5, 4), seed=1)
        outputs = dense(inputs)
        self.assertListEqual([5, 2], outputs.get_shape().as_list())
        self.assertListEqual(dense.variables, [dense.kernel, dense.bias])
        self.assertListEqual(dense.trainable_variables,
                             [dense.kernel, dense.bias])
        self.assertListEqual(dense.non_trainable_variables, [])
        if not context.executing_eagerly():
            self.assertEqual(
                len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 2)
        self.assertEqual(dense.kernel.name, 'my_dense/kernel:0')
        self.assertEqual(dense.bias.name, 'my_dense/bias:0')

    @test_util.assert_no_new_pyobjects_executing_eagerly
    def testNoEagerLeak(self):
        # Tests that repeatedly constructing and building a Layer does not leak
        # Python objects.
        inputs = random_ops.random_uniform((5, 4), seed=1)
        core_layers.Dense(5)(inputs)
        core_layers.Dense(2, activation=nn_ops.relu, name='my_dense')(inputs)

    @combinations.generate(combinations.combine(mode=['graph', 'eager']))
    def testCallTensorDot(self):
        dense = core_layers.Dense(2, activation=nn_ops.relu, name='my_dense')
        inputs = random_ops.random_uniform((5, 4, 3), seed=1)
        outputs = dense(inputs)
        self.assertListEqual([5, 4, 2], outputs.get_shape().as_list())

    @combinations.generate(combinations.combine(mode=['graph', 'eager']))
    def testNoBias(self):
        dense = core_layers.Dense(2, use_bias=False, name='my_dense')
        inputs = random_ops.random_uniform((5, 2), seed=1)
        _ = dense(inputs)
        self.assertListEqual(dense.variables, [dense.kernel])
        self.assertListEqual(dense.trainable_variables, [dense.kernel])
        self.assertListEqual(dense.non_trainable_variables, [])
        if not context.executing_eagerly():
            self.assertEqual(
                len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 1)
        self.assertEqual(dense.kernel.name, 'my_dense/kernel:0')
        self.assertEqual(dense.bias, None)

    @combinations.generate(combinations.combine(mode=['graph', 'eager']))
    def testNonTrainable(self):
        dense = core_layers.Dense(2, trainable=False, name='my_dense')
        inputs = random_ops.random_uniform((5, 2), seed=1)
        _ = dense(inputs)
        self.assertListEqual(dense.variables, [dense.kernel, dense.bias])
        self.assertListEqual(dense.non_trainable_variables,
                             [dense.kernel, dense.bias])
        self.assertListEqual(dense.trainable_variables, [])
        if not context.executing_eagerly():
            self.assertEqual(
                len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 0)

    @combinations.generate(combinations.combine(mode=['graph', 'eager']))
    def testOutputShape(self):
        dense = core_layers.Dense(7, activation=nn_ops.relu, name='my_dense')
        inputs = random_ops.random_uniform((5, 3), seed=1)
        outputs = dense.apply(inputs)
        self.assertEqual(outputs.get_shape().as_list(), [5, 7])

        inputs = random_ops.random_uniform((5, 2, 3), seed=1)
        outputs = dense(inputs)
        self.assertEqual(outputs.get_shape().as_list(), [5, 2, 7])

        inputs = random_ops.random_uniform((1, 2, 4, 3), seed=1)
        outputs = dense.apply(inputs)
        self.assertEqual(outputs.get_shape().as_list(), [1, 2, 4, 7])

    @test_util.run_deprecated_v1
    def testCallOnPlaceHolder(self):
        inputs = array_ops.placeholder(dtype=dtypes.float32)
        dense = core_layers.Dense(4, name='my_dense')
        with self.assertRaises(ValueError):
            dense(inputs)

        inputs = array_ops.placeholder(dtype=dtypes.float32,
                                       shape=[None, None])
        dense = core_layers.Dense(4, name='my_dense')
        with self.assertRaises(ValueError):
            dense(inputs)

        inputs = array_ops.placeholder(dtype=dtypes.float32,
                                       shape=[None, None, None])
        dense = core_layers.Dense(4, name='my_dense')
        with self.assertRaises(ValueError):
            dense(inputs)

        inputs = array_ops.placeholder(dtype=dtypes.float32, shape=[None, 3])
        dense = core_layers.Dense(4, name='my_dense')
        dense(inputs)

        inputs = array_ops.placeholder(dtype=dtypes.float32,
                                       shape=[None, None, 3])
        dense = core_layers.Dense(4, name='my_dense')
        dense(inputs)

    @combinations.generate(combinations.combine(mode=['graph', 'eager']))
    def testActivation(self):
        dense = core_layers.Dense(2, activation=nn_ops.relu, name='dense1')
        inputs = random_ops.random_uniform((5, 3), seed=1)
        outputs = dense(inputs)
        if not context.executing_eagerly():
            self.assertEqual(outputs.op.name, 'dense1/Relu')

        dense = core_layers.Dense(2, name='dense2')
        inputs = random_ops.random_uniform((5, 3), seed=1)
        outputs = dense(inputs)
        if not context.executing_eagerly():
            self.assertEqual(outputs.op.name, 'dense2/BiasAdd')

    @test_util.run_deprecated_v1
    def testActivityRegularizer(self):
        regularizer = lambda x: math_ops.reduce_sum(x) * 1e-3
        dense = core_layers.Dense(2,
                                  name='my_dense',
                                  activity_regularizer=regularizer)
        inputs = random_ops.random_uniform((5, 3), seed=1)
        _ = dense(inputs)
        loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
        self.assertEqual(len(loss_keys), 1)
        self.assertListEqual(dense.losses, loss_keys)

    @test_util.run_deprecated_v1
    def testKernelRegularizer(self):
        regularizer = lambda x: math_ops.reduce_sum(x) * 1e-3
        dense = core_layers.Dense(2,
                                  name='my_dense',
                                  kernel_regularizer=regularizer)
        inputs = random_ops.random_uniform((5, 3), seed=1)
        _ = dense(inputs)
        loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
        self.assertEqual(len(loss_keys), 1)
        self.evaluate([v.initializer for v in dense.variables])
        self.assertAllEqual(self.evaluate(dense.losses),
                            self.evaluate(loss_keys))

    @test_util.run_deprecated_v1
    def testKernelRegularizerWithReuse(self):
        regularizer = lambda x: math_ops.reduce_sum(x) * 1e-3
        inputs = random_ops.random_uniform((5, 3), seed=1)
        _ = core_layers.dense(inputs,
                              2,
                              name='my_dense',
                              kernel_regularizer=regularizer)
        self.assertEqual(
            len(ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)), 1)
        _ = core_layers.dense(inputs,
                              2,
                              name='my_dense',
                              kernel_regularizer=regularizer,
                              reuse=True)
        self.assertEqual(
            len(ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)), 1)

    @test_util.run_deprecated_v1
    def testBiasRegularizer(self):
        regularizer = lambda x: math_ops.reduce_sum(x) * 1e-3
        dense = core_layers.Dense(2,
                                  name='my_dense',
                                  bias_regularizer=regularizer)
        inputs = random_ops.random_uniform((5, 3), seed=1)
        _ = dense(inputs)
        loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
        self.assertEqual(len(loss_keys), 1)
        self.evaluate([v.initializer for v in dense.variables])
        self.assertAllEqual(self.evaluate(dense.losses),
                            self.evaluate(loss_keys))

    @test_util.run_deprecated_v1
    def testFunctionalDense(self):
        with self.cached_session():
            inputs = random_ops.random_uniform((5, 3), seed=1)
            outputs = core_layers.dense(inputs,
                                        2,
                                        activation=nn_ops.relu,
                                        name='my_dense')
            self.assertEqual(
                len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 2)
            self.assertEqual(outputs.op.name, 'my_dense/Relu')

    @test_util.run_deprecated_v1
    def testFunctionalDenseTwice(self):
        inputs = random_ops.random_uniform((5, 3), seed=1)
        core_layers.dense(inputs, 2)
        vars1 = _get_variable_dict_from_varstore().values()
        core_layers.dense(inputs, 2)
        vars2 = _get_variable_dict_from_varstore().values()
        self.assertEqual(len(vars1), 2)
        self.assertEqual(len(vars2), 4)

    # TODO(alive): get this to work in eager mode.
    def testFunctionalDenseTwiceReuse(self):
        with self.cached_session():
            inputs = random_ops.random_uniform((5, 3), seed=1)
            core_layers.dense(inputs, 2, name='my_dense')
            vars1 = variables.trainable_variables()
            core_layers.dense(inputs, 2, name='my_dense', reuse=True)
            vars2 = variables.trainable_variables()
            self.assertEqual(vars1, vars2)

    # TODO(alive): get this to work in eager mode.
    def testFunctionalDenseTwiceReuseFromScope(self):
        with self.cached_session():
            with variable_scope.variable_scope('scope'):
                inputs = random_ops.random_uniform((5, 3), seed=1)
                core_layers.dense(inputs, 2, name='my_dense')
                vars1 = variables.trainable_variables()
            with variable_scope.variable_scope('scope', reuse=True):
                core_layers.dense(inputs, 2, name='my_dense')
                vars2 = variables.trainable_variables()
            self.assertEqual(vars1, vars2)

    @test_util.run_deprecated_v1
    def testFunctionalDenseInitializerFromScope(self):
        with variable_scope.variable_scope(
                'scope', initializer=init_ops.ones_initializer(
                )), self.cached_session():
            inputs = random_ops.random_uniform((5, 3), seed=1)
            core_layers.dense(inputs, 2)
            self.evaluate(variables.global_variables_initializer())
            weights = _get_variable_dict_from_varstore()
            self.assertEqual(len(weights), 2)
            # Check that the matrix weights got initialized to ones (from scope).
            self.assertAllClose(weights['scope/dense/kernel'].read_value(),
                                np.ones((3, 2)))
            # Check that the bias still got initialized to zeros.
            self.assertAllClose(weights['scope/dense/bias'].read_value(),
                                np.zeros((2)))

    def testEagerExecution(self):
        with context.eager_mode():
            container = variable_scope.EagerVariableStore()
            x = constant_op.constant([[2.0]])
            with container.as_default():
                y = core_layers.dense(
                    x,
                    1,
                    name='my_dense',
                    kernel_initializer=init_ops.ones_initializer())
            self.assertAllEqual(y, [[2.0]])
            self.assertEqual(len(container.variables()), 2)
            # Recreate the layer to test reuse.
            with container.as_default():
                core_layers.dense(
                    x,
                    1,
                    name='my_dense',
                    kernel_initializer=init_ops.ones_initializer())
            self.assertEqual(len(container.variables()), 2)

    def testFunctionalDenseWithCustomGetter(self):
        called = [0]

        def custom_getter(getter, *args, **kwargs):
            called[0] += 1
            return getter(*args, **kwargs)

        with variable_scope.variable_scope('test',
                                           custom_getter=custom_getter):
            inputs = random_ops.random_uniform((5, 3), seed=1)
            core_layers.dense(inputs, 2)
        self.assertEqual(called[0], 2)

    @test_util.run_deprecated_v1
    def testFunctionalDenseInScope(self):
        with self.cached_session():
            with variable_scope.variable_scope('test'):
                inputs = random_ops.random_uniform((5, 3), seed=1)
                core_layers.dense(inputs, 2, name='my_dense')
                var_dict = _get_variable_dict_from_varstore()
                var_key = 'test/my_dense/kernel'
                self.assertEqual(var_dict[var_key].name, '%s:0' % var_key)
            with variable_scope.variable_scope('test1') as scope:
                inputs = random_ops.random_uniform((5, 3), seed=1)
                core_layers.dense(inputs, 2, name=scope)
                var_dict = _get_variable_dict_from_varstore()
                var_key = 'test1/kernel'
                self.assertEqual(var_dict[var_key].name, '%s:0' % var_key)
            with variable_scope.variable_scope('test2'):
                inputs = random_ops.random_uniform((5, 3), seed=1)
                core_layers.dense(inputs, 2)
                var_dict = _get_variable_dict_from_varstore()
                var_key = 'test2/dense/kernel'
                self.assertEqual(var_dict[var_key].name, '%s:0' % var_key)

    @combinations.generate(combinations.combine(mode=['graph', 'eager']))
    def testComputeOutputShape(self):
        dense = core_layers.Dense(2, activation=nn_ops.relu, name='dense1')
        ts = tensor_shape.TensorShape
        # pylint: disable=protected-access
        with self.assertRaises(ValueError):
            dense.compute_output_shape(ts(None))
        with self.assertRaises(ValueError):
            dense.compute_output_shape(ts([]))
        with self.assertRaises(ValueError):
            dense.compute_output_shape(ts([1]))
        self.assertEqual([None, 2],
                         dense.compute_output_shape((None, 3)).as_list())
        self.assertEqual([None, 2],
                         dense.compute_output_shape(ts([None, 3])).as_list())
        self.assertEqual([None, 4, 2],
                         dense.compute_output_shape(ts([None, 4,
                                                        3])).as_list())
        # pylint: enable=protected-access

    @combinations.generate(combinations.combine(mode=['graph', 'eager']))
    def testConstraints(self):
        k_constraint = lambda x: x / math_ops.reduce_sum(x)
        b_constraint = lambda x: x / math_ops.reduce_max(x)
        dense = core_layers.Dense(2,
                                  kernel_constraint=k_constraint,
                                  bias_constraint=b_constraint)
        inputs = random_ops.random_uniform((5, 3), seed=1)
        dense(inputs)
        self.assertEqual(dense.kernel_constraint, k_constraint)
        self.assertEqual(dense.bias_constraint, b_constraint)
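
The same Dense behaviors in public tf.keras form (a minimal sketch; the constraint lambda mirrors testConstraints above and is illustrative, not a recommended constraint):

import tensorflow as tf

dense = tf.keras.layers.Dense(
    2, activation=tf.nn.relu,
    kernel_constraint=lambda w: w / tf.reduce_sum(w))
out = dense(tf.random.uniform((5, 3)))
print(out.shape)  # (5, 2): only the last axis is transformed
print([v.name for v in dense.trainable_variables])  # kernel and bias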

class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
    def _run_if_in_graph_mode(self, val):
        # Running the return value only in graph mode is useful: in graph
        # mode, optimizers return an op that can be run with self.evaluate,
        # whereas in eager mode the update has already been applied and the
        # return value cannot be run.
        if not context.executing_eagerly():
            self.evaluate(val)

    def _run_fn_with_grad_check(self, strategy, var, opt, expected_grad):
        grad_check_fn = mp_test_util.create_identity_with_grad_check_fn(
            expected_grad)
        loss = lambda: grad_check_fn(var) / strategy.num_replicas_in_sync
        return lambda: opt.minimize(loss, var_list=[var])

    @parameterized.named_parameters(*TESTCASES)
    def testFixedLossScaleAppliedToLossWithMinimize(self, strategy_fn):
        with strategy_fn().scope() as strategy:
            var = variables.Variable([5.0])
            opt = gradient_descent.SGD(2.0)
            loss_scale = 10.
            opt = loss_scale_optimizer.LossScaleOptimizer(
                opt, dynamic=False, initial_scale=loss_scale)
            self.assertEqual(self.evaluate(opt.loss_scale), loss_scale)
            self.assertIsInstance(opt.loss_scale, ops.Tensor)
            # We need num_replicas_in_sync to divide loss_scale, otherwise loss_scale
            # / strategy.num_replicas_in_sync will not be exact, which could lead to
            # assertion failures due to rounding issues.
            self.assertEqual(loss_scale % strategy.num_replicas_in_sync, 0)
            run_fn = self._run_fn_with_grad_check(
                strategy, var, opt, loss_scale / strategy.num_replicas_in_sync)
            run_op = strategy.experimental_run(run_fn)
            self.evaluate(variables.global_variables_initializer())
            self._run_if_in_graph_mode(run_op)
            # The loss is the identity of the variable. Therefore the gradient is 1,
            # and so the variable will be init_val - grad * lr == 5 - 1 * 2 == 3
            self.assertAllClose([3.], self.evaluate(var))

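Fixed loss scaling multiplies the loss by the scale before differentiation and divides the gradients by it afterwards, so the applied update is unchanged. A minimal sketch, assuming the tf.keras.mixed_precision.LossScaleOptimizer API (TF 2.4+):

import tensorflow as tf

var = tf.Variable([5.0])
opt = tf.keras.mixed_precision.LossScaleOptimizer(
    tf.keras.optimizers.SGD(2.0), dynamic=False, initial_scale=10.)
with tf.GradientTape() as tape:
    scaled_loss = opt.get_scaled_loss(tf.reduce_sum(var))  # loss * 10
scaled_grads = tape.gradient(scaled_loss, [var])           # gradient * 10
grads = opt.get_unscaled_gradients(scaled_grads)           # divided back out
opt.apply_gradients(zip(grads, [var]))
print(var.numpy())  # [3.] == 5 - 1 * 2, identical to unscaled SGD
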
    def testFixedLossScaleAppliedToLossWithGetGradients(self):
        with ops.Graph().as_default():
            var = variables.Variable([2.0])
            opt = gradient_descent.SGD(1.0)
            loss_scale = 10.
            opt = loss_scale_optimizer.LossScaleOptimizer(
                opt, dynamic=False, initial_scale=loss_scale)
            grad_check_fn = mp_test_util.create_identity_with_grad_check_fn(
                loss_scale)
            loss = grad_check_fn(var)
            run_op = opt.get_gradients(loss, [var])
            self.evaluate(variables.global_variables_initializer())
            # This will cause an assertion to run, as
            # mp_test_util.create_identity_with_grad_check_fn added an assertion op.
            self.evaluate(run_op)

    def testDynamicAttrsWithFixedLossScale(self):
        opt = gradient_descent.SGD()
        opt = loss_scale_optimizer.LossScaleOptimizer(opt,
                                                      dynamic=False,
                                                      initial_scale=2.)
        self.assertFalse(opt.dynamic)
        self.assertIsNone(opt.dynamic_counter)
        self.assertIsNone(opt.dynamic_growth_steps)

    def testGetScaledLoss(self):
        opt = gradient_descent.SGD(2.0)
        opt = loss_scale_optimizer.LossScaleOptimizer(opt,
                                                      dynamic=False,
                                                      initial_scale=2.)
        loss = ops.convert_to_tensor_v2_with_dispatch(5.)
        self.assertEqual(10., self.evaluate(opt.get_scaled_loss(loss)))
        self.assertEqual(10.,
                         self.evaluate(opt.get_scaled_loss(lambda: loss)()))
        loss = ops.convert_to_tensor_v2_with_dispatch(5., dtype='float16')
        self.assertEqual(10., self.evaluate(opt.get_scaled_loss(loss)))
        self.assertEqual(10.,
                         self.evaluate(opt.get_scaled_loss(lambda: loss)()))

    def testGetUnscaledGradients(self):
        opt = gradient_descent.SGD(2.0)
        opt = loss_scale_optimizer.LossScaleOptimizer(opt,
                                                      dynamic=False,
                                                      initial_scale=2)
        scaled_grads = [
            ops.convert_to_tensor_v2_with_dispatch(3.), None,
            ops.convert_to_tensor_v2_with_dispatch(-4., dtype='float16')
        ]
        grads = opt.get_unscaled_gradients(scaled_grads)
        grads = [self.evaluate(g) if g is not None else g for g in grads]
        self.assertEqual([1.5, None, -2.], grads)

    def testGetUnscaledSparseGradients(self):
        opt = gradient_descent.SGD(2.0)
        opt = loss_scale_optimizer.LossScaleOptimizer(opt,
                                                      dynamic=False,
                                                      initial_scale=2)
        sparse_scaled_grad = ops.IndexedSlices(
            ops.convert_to_tensor_v2_with_dispatch([[4., 2.], [8., 5.]]),
            ops.convert_to_tensor_v2_with_dispatch([1, 3], dtype='int32'),
            dense_shape=ops.convert_to_tensor_v2_with_dispatch([5, 2],
                                                               dtype='int32'))
        sparse_grad = opt.get_unscaled_gradients([sparse_scaled_grad])[0]
        self.assertIsInstance(sparse_grad, ops.IndexedSlices)
        self.assertAllEqual([[2., 1.], [4., 2.5]],
                            self.evaluate(sparse_grad.values))

    @parameterized.named_parameters(*TESTCASES)
    def testDynamicLossScale(self, strategy_fn):
        strategy = strategy_fn()
        learning_rate = 2.
        expected_gradient = variables.Variable(learning_rate /
                                               strategy.num_replicas_in_sync)
        with strategy.scope():
            var = variables.Variable([5.0])
            opt = gradient_descent.SGD(learning_rate)
            opt = loss_scale_optimizer.LossScaleOptimizer(
                opt, initial_scale=2, dynamic_growth_steps=1)
            self.assertEqual(opt.initial_scale, 2.)
            self.assertIsInstance(opt.initial_scale, float)
            self.assertEqual(opt.dynamic_growth_steps, 1)
            self.assertIsInstance(opt.dynamic_growth_steps, int)

            self.assertEqual(opt.initial_scale % strategy.num_replicas_in_sync,
                             0)
            run_fn = self._run_fn_with_grad_check(strategy, var, opt,
                                                  expected_gradient)
            run_op = strategy.experimental_run(run_fn)
            self.evaluate(variables.global_variables_initializer())
            self._run_if_in_graph_mode(run_op)
            # The loss is the identity of the variable. Therefore the gradient is 1,
            # and so the variable will be init_val - grad * lr == 5 - 1 * 2 == 3
            self.assertAllClose([3.], self.evaluate(var))

            # The loss scale has doubled, so the expected gradient is also doubled.
            self.evaluate(
                expected_gradient.assign(2 * learning_rate /
                                         strategy.num_replicas_in_sync))
            run_op = strategy.experimental_run(run_fn)
            self._run_if_in_graph_mode(run_op)
            # As before, 2 is subtracted from the variable, making its new
            # value 1.
            self.assertAllClose([1.], self.evaluate(var))

    def testDynamicLossScaleDefaultValues(self):
        opt = gradient_descent.SGD()
        opt = loss_scale_optimizer.LossScaleOptimizer(opt)
        self.assertEqual(opt.initial_scale, 2**15)
        self.assertEqual(opt.dynamic_growth_steps, 2000)
        self.evaluate(variables.global_variables_initializer())
        self.assertEqual(self.evaluate(opt.loss_scale), 2**15)

    # pylint: disable=cell-var-from-loop
    @parameterized.named_parameters(*TESTCASES)
    def testClipping(self, strategy_fn):
        strategy = strategy_fn()
        learning_rate = 2.
        for clip_type in ('clipnorm', 'global_clipnorm', 'clipvalue'):
            with strategy.scope(), self.subTest(clip_type=clip_type):
                var = variables.Variable([5.0])
                opt = gradient_descent.SGD(learning_rate, **{clip_type: 2.0})
                opt = loss_scale_optimizer.LossScaleOptimizer(
                    opt, initial_scale=2, dynamic_growth_steps=1)
                self.assertEqual(getattr(opt, clip_type), 2.0)
                self.assertEqual(
                    opt.initial_scale % strategy.num_replicas_in_sync, 0)

                loss = lambda: var * 4 / strategy.num_replicas_in_sync
                run_fn = lambda: opt.minimize(loss, var_list=[var])

                # Test running with clipped gradients
                run_op = strategy.experimental_run(run_fn)
                self.evaluate(variables.global_variables_initializer())
                self._run_if_in_graph_mode(run_op)
                # The gradient is 4 but is clipped to 2, so the variable will be
                # init_val - clipped_grad * lr == 5 - 2 * 2 == 1
                self.assertAllClose([1.], self.evaluate(var))
                self.assertEqual(self.evaluate(opt.loss_scale), 4)

                # Test changing the clip amount and running again
                setattr(opt, clip_type, 3.0)
                run_op = strategy.experimental_run(run_fn)
                self._run_if_in_graph_mode(run_op)
                # The gradient is 4 but is clipped to 3, so the variable will be
                # prev_var - clipped_grad * lr == 1 - 3 * 2 == -5
                self.assertAllClose([-5.], self.evaluate(var))
                self.assertEqual(self.evaluate(opt.loss_scale), 8)

                # Test Inf gradients are still skipped instead of being clipped
                loss = lambda: var * float('Inf')
                run_fn = lambda: opt.minimize(loss, var_list=[var])
                run_op = strategy.experimental_run(run_fn)
                self._run_if_in_graph_mode(run_op)
                self.assertAllClose([-5.],
                                    self.evaluate(var))  # Var does not change
                self.assertEqual(self.evaluate(opt.loss_scale), 4)

    # pylint: enable=cell-var-from-loop

    @parameterized.named_parameters(*TESTCASES)
    def testDynamicUpdate(self, strategy_fn):
        with strategy_fn().scope() as strategy:
            var = variables.Variable([1.0, 2.0])
            opt = gradient_descent.SGD(1.0)
            opt = loss_scale_optimizer.LossScaleOptimizer(
                opt, initial_scale=2, dynamic_growth_steps=1)

            # Test optimizer with finite gradients
            loss = lambda: var * 2.0 / strategy.num_replicas_in_sync
            run_fn = lambda: opt.minimize(loss, var_list=[var])
            run_op = strategy.experimental_run(run_fn)
            self.evaluate(variables.global_variables_initializer())
            self._run_if_in_graph_mode(run_op)
            # Gradient is 2, so variable will have 2 subtracted from it
            self.assertAllClose([-1.0, 0.0], self.evaluate(var))
            # Loss scale has doubled from 2 to 4
            self.assertEqual(4., self.evaluate(opt.loss_scale))

            # Test optimizer with NaN gradients
            loss = lambda: var * float('NaN')
            run_fn = lambda: opt.minimize(loss, var_list=[var])
            run_op = strategy.experimental_run(run_fn)
            self._run_if_in_graph_mode(run_op)
            # Variable should not change from before, due to NaN gradients.
            self.assertAllClose(self.evaluate(var), [-1.0, 0.0])
            # Loss scale should halve due to the NaN gradients.
            self.assertEqual(2., self.evaluate(opt.loss_scale))

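The dynamic update rule tested above: after dynamic_growth_steps consecutive finite steps the scale doubles, and any non-finite gradient halves it and skips the variable update. A minimal sketch under the same TF 2.4+ public-API assumption:

import tensorflow as tf

var = tf.Variable([1.0])
opt = tf.keras.mixed_precision.LossScaleOptimizer(
    tf.keras.optimizers.SGD(1.0), initial_scale=2., dynamic_growth_steps=1)
opt.minimize(lambda: var * 2.0, var_list=[var])
print(opt.loss_scale.numpy())  # 4.0: doubled after one finite step
opt.minimize(lambda: var * float('nan'), var_list=[var])
print(opt.loss_scale.numpy())  # 2.0: halved after a non-finite step
print(var.numpy())             # [-1.]: the NaN step left var unchanged
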
    @parameterized.named_parameters(*TESTCASES)
    def testDynamicLossScaleWithFloat16Loss(self, strategy_fn):
        strategy = strategy_fn()
        learning_rate = 2.
        with strategy.scope():
            var = variables.Variable([5.0])
            opt = gradient_descent.SGD(learning_rate)
            opt = loss_scale_optimizer.LossScaleOptimizer(
                opt, initial_scale=2, dynamic_growth_steps=1)

            def loss():
                return math_ops.cast(var / strategy.num_replicas_in_sync,
                                     'float16')

            run_fn = lambda: opt.minimize(loss, var_list=[var])
            run_op = strategy.experimental_run(run_fn)
            self.evaluate(variables.global_variables_initializer())
            self._run_if_in_graph_mode(run_op)
            # The loss is the identity of the variable. Therefore the gradient is 1,
            # and so the variable will be init_val - grad * lr == 5 - 1 * 2 == 3
            self.assertAllClose([3.], self.evaluate(var))

    @parameterized.named_parameters(*TESTCASES)
    def testDynamicLossScaleWithSlots(self, strategy_fn):
        strategy_obj = strategy_fn()
        if (isinstance(strategy_obj, mirrored_strategy.MirroredStrategy)
                and control_flow_v2_toggles.control_flow_v2_enabled()
                and not context.executing_eagerly()):
            self.skipTest('b/138667997')
        with strategy_obj.scope() as strategy:
            var = variables.Variable([1.0, 2.0])
            # An SGD optimizer with momentum has slot variables.
            opt = gradient_descent.SGD(1.0, momentum=1.)
            initial_scale = 2.
            opt = loss_scale_optimizer.LossScaleOptimizer(
                opt, initial_scale=initial_scale, dynamic_growth_steps=1)
            loss = lambda: var / strategy.num_replicas_in_sync
            run_fn = lambda: opt.minimize(loss, var_list=[var])
            run_op = strategy.experimental_run(run_fn)
            self.evaluate(variables.global_variables_initializer())
            self._run_if_in_graph_mode(run_op)
            # The momentum accumulator starts at 0 and the gradient is 1. The
            # accumulator is incremented by the gradient, so it is now 1. The
            # accumulator is then subtracted from the variable, so the
            # variable decreases by 1.
            self.assertAllClose([0.0, 1.0], self.evaluate(var))
            self.assertEqual(self.evaluate(opt.loss_scale), initial_scale * 2)

            run_op = strategy.experimental_run(run_fn)
            self._run_if_in_graph_mode(run_op)
            # The momentum accumulator was 1 before this step and the gradient
            # is 1. The accumulator is incremented by the gradient, so it is
            # now 2. The accumulator is then subtracted from the variable, so
            # the variable decreases by 2.
            self.assertAllClose([-2., -1.], self.evaluate(var))
            self.assertEqual(self.evaluate(opt.loss_scale), initial_scale * 4)

            self.assertEqual(opt.get_slot_names(), ['momentum'])

    def testIterations(self):
        opt = gradient_descent.SGD(2.0)
        lso = loss_scale_optimizer.LossScaleOptimizer(opt,
                                                      dynamic=False,
                                                      initial_scale=10.)
        lso.iterations = 7
        self.assertEqual(lso.iterations, 7)
        self.assertEqual(opt.iterations, 7)

    @parameterized.named_parameters(*TESTCASES)
    def testIterationsIncremented(self, strategy_fn):
        with strategy_fn().scope() as strategy:
            # Test that iterations is incremented by opt.minimize.
            opt = gradient_descent.SGD(1.0)
            opt = loss_scale_optimizer.LossScaleOptimizer(opt)
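            # By default the wrapper is dynamic, starting at a scale of 2 ** 15
            # with dynamic_growth_steps=2000.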
            var = variables.Variable([5.0])
            loss = lambda: var * 2.0 / strategy.num_replicas_in_sync
            run_fn = lambda: opt.minimize(loss, [var])
            run_op = strategy.experimental_run(run_fn)
            self.evaluate(variables.global_variables_initializer())
            self._run_if_in_graph_mode(run_op)
            self.assertEqual(self.evaluate(var),
                             3.0)  # Grad is 2, so var is 5 - 2
            self.assertEqual(self.evaluate(opt.iterations), 1)

            # Test that iterations is incremented by opt.minimize even if
            # gradients aren't applied to the variables due to NaN gradients.
            loss = lambda: var * float('NaN')
            run_fn = lambda: opt.minimize(loss, [var])
            run_op = strategy.experimental_run(run_fn)
            self._run_if_in_graph_mode(run_op)
            self.assertEqual(self.evaluate(var), 3.0)
            self.assertEqual(self.evaluate(opt.iterations), 2)

    def testWeightMethods(self):
        with self.test_session():
            var = variables.Variable([1.0])
            opt = gradient_descent.SGD(1.0)
            opt = loss_scale_optimizer.LossScaleOptimizer(
                opt, initial_scale=2., dynamic_growth_steps=1)
            run_op = opt.minimize(lambda: var * 2, [var])
            self.evaluate(variables.global_variables_initializer())
            self._run_if_in_graph_mode(run_op)

            self.assertLen(opt.weights, 1)  # The 'iterations' weight
            self.assertEqual(self.evaluate(opt.weights[0]), 1)
            self.assertEqual(opt.get_weights()[0], 1)
            self.assertEqual(self.evaluate(opt.variables()[0]), 1)
            opt.set_weights([np.array(2.)])
            self.assertEqual(self.evaluate(opt.variables()[0]), 2)

    def testHyperParametersExposed(self):
        with self.cached_session():
            opt = adam.Adam(learning_rate=1.0, beta_1=0.5, beta_2=0.9)
            lso = loss_scale_optimizer.LossScaleOptimizer(opt)
            # Force hyperparameters to be created
            opt.lr  # pylint: disable=pointless-statement
            self.evaluate(variables.global_variables_initializer())

            self.assertEqual(self.evaluate(lso.beta_1), 0.5)
            self.assertIsInstance(lso.beta_1, variables.Variable)
            self.assertEqual(self.evaluate(lso.lr), 1.0)
            self.assertIs(lso.lr, opt.lr)
            self.assertIs(lso.lr, lso.learning_rate)

            lso.beta_1 = 0.25
            self.assertEqual(self.evaluate(lso.beta_1), 0.25)
            self.assertEqual(self.evaluate(opt.beta_1), 0.25)
            self.assertIs(lso.beta_1, opt.beta_1)
            opt.beta_1 = 0.75
            self.assertEqual(self.evaluate(lso.beta_1), 0.75)
            self.assertEqual(self.evaluate(opt.beta_1), 0.75)
            self.assertIs(lso.beta_1, opt.beta_1)
            lso.lr = 2.0
            self.assertEqual(self.evaluate(lso.lr), 2.0)
            self.assertEqual(self.evaluate(lso.learning_rate), 2.0)
            self.assertEqual(self.evaluate(opt.lr), 2.0)
            self.assertEqual(self.evaluate(opt.learning_rate), 2.0)
            self.assertIs(lso.lr, opt.lr)

            # Test setting an attribute that is both an attribute on
            # LossScaleOptimizer and a hyperparameter on the wrapped optimizer.
            class MyOpt(gradient_descent.SGD):
                def __init__(self):
                    super().__init__()
                    self._set_hyper('loss_scale', 123.)

            opt = MyOpt()
            lso = loss_scale_optimizer.LossScaleOptimizer(opt)
            with self.assertRaises(AttributeError):
                lso.loss_scale = 2.

    def testArbitraryAttributesNotExposed(self):
        opt = gradient_descent.SGD()
        lso = loss_scale_optimizer.LossScaleOptimizer(opt)
        self.assertFalse(opt.nesterov)
        with self.assertRaisesRegex(
                AttributeError,
                "'LossScaleOptimizer' object has no attribute 'nesterov'"):
            lso.nesterov  # pylint: disable=pointless-statement

        lso.nesterov = True
        self.assertTrue(lso.nesterov)
        self.assertFalse(opt.nesterov)

    def testDir(self):
        lso = loss_scale_optimizer.LossScaleOptimizer(gradient_descent.SGD())
        dir_result = dir(lso)
        self.assertIn('learning_rate', dir_result)  # Hyperparameter
        self.assertIn('lr', dir_result)  # Hyperparameter
        self.assertIn('minimize', dir_result)  # Attribute
        self.assertIn('loss_scale', dir_result)  # Attribute
        self.assertNotIn('nesterov',
                         dir_result)  # Attribute on inner optimizer
        self.assertIn('nesterov', dir(lso.inner_optimizer))

    def testApplyGradientsGetsUnwrappedTensors(self):
        # Tests that gradients passed to apply_gradients are not wrapped in a
        # DistributionStrategy wrapper, such as PerReplica, but instead are raw
        # Tensors. Optimizer subclasses that override apply_gradients() expect raw
        # Tensors, even though the base Optimizer can handle PerReplica gradients.

        outer_self = self

        class MyOptimizer(gradient_descent.SGD):
            def apply_gradients(self,
                                grads_and_vars,
                                name=None,
                                experimental_aggregate_gradients=True):
                for grad, _ in grads_and_vars:
                    outer_self.assertIsInstance(grad, ops.Tensor)
                return super(MyOptimizer, self).apply_gradients(
                    grads_and_vars, name, experimental_aggregate_gradients)

        with create_mirrored_strategy().scope() as strategy:
            var = variables.Variable([5.0])
            opt = MyOptimizer(learning_rate=1.0)
            opt = loss_scale_optimizer.LossScaleOptimizer(opt,
                                                          dynamic=False,
                                                          initial_scale=1)
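            # With dynamic=False and initial_scale=1, scaling is a no-op, so
            # this isolates whether apply_gradients receives plain Tensors
            # rather than PerReplica values.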
            loss = lambda: var * 2.0
            run_fn = lambda: opt.minimize(loss, [var])
            strategy.experimental_run(run_fn)

    @parameterized.named_parameters(*TESTCASES)
    def testV1Optimizer(self, strategy_fn):
        strategy = strategy_fn()
        learning_rate = 2.
        with strategy.scope():
            # Test FixedLossScale
            var = variables.Variable([5.0])
            opt = gradient_descent.SGD(learning_rate)
            opt = loss_scale_optimizer.LossScaleOptimizerV1(opt, loss_scale=2)
            self.assertIsInstance(opt.loss_scale, ops.Tensor)
            self.evaluate(variables.global_variables_initializer())
            self.assertEqual(self.evaluate(opt.loss_scale), 2)
            self.assertEqual(opt.initial_scale, 2)
            self.assertIsNone(opt.dynamic_growth_steps)
            run_fn = self._run_fn_with_grad_check(
                strategy, var, opt, 2 / strategy.num_replicas_in_sync)
            run_op = strategy.experimental_run(run_fn)
            self.evaluate(variables.global_variables_initializer())
            self._run_if_in_graph_mode(run_op)
            # The loss is the identity of the variable. Therefore the gradient is 1,
            # and so the variable will be init_val - grad * lr == 5 - 1 * 2 == 3
            self.assertAllClose([3.], self.evaluate(var))

            # Test DynamicLossScale
            var = variables.Variable([5.0])
            opt = gradient_descent.SGD(learning_rate)
            opt = loss_scale_optimizer.LossScaleOptimizerV1(opt, 'dynamic')
            self.assertEqual(opt.initial_scale, 2**15)
            self.assertEqual(opt.dynamic_growth_steps, 2000)
            self.evaluate(variables.global_variables_initializer())
            self.assertEqual(self.evaluate(opt.loss_scale), 2**15)
            for s in strategy.experimental_local_results(opt.dynamic_counter):
                self.assertEqual(self.evaluate(s), 0)

            loss = lambda: var * float('NaN')
            run_fn = lambda: opt.minimize(loss, var_list=[var])
            run_op = strategy.experimental_run(run_fn)
            self.evaluate(variables.global_variables_initializer())
            self._run_if_in_graph_mode(run_op)
            self.assertAllClose([5.], self.evaluate(var))
            self.assertEqual(self.evaluate(opt.loss_scale), 2**14)
            for s in strategy.experimental_local_results(opt.dynamic_counter):
                self.assertEqual(self.evaluate(s), 0)

    @parameterized.named_parameters(*TESTCASES)
    def testPassingV1LossScale(self, strategy_fn):
        strategy = strategy_fn()
        learning_rate = 2.
        with strategy.scope():
            # Test FixedLossScale
            var = variables.Variable([5.0])
            opt = gradient_descent.SGD(learning_rate)
            loss_scale = tf_loss_scale_module.FixedLossScale(2.)
            opt = loss_scale_optimizer.LossScaleOptimizerV1(opt, loss_scale)
            self.assertIsInstance(opt.loss_scale, ops.Tensor)
            self.evaluate(variables.global_variables_initializer())
            self.assertEqual(self.evaluate(opt.loss_scale), 2)
            run_fn = self._run_fn_with_grad_check(
                strategy, var, opt, 2 / strategy.num_replicas_in_sync)
            run_op = strategy.experimental_run(run_fn)
            self.evaluate(variables.global_variables_initializer())
            self._run_if_in_graph_mode(run_op)
            # The loss is the identity of the variable. Therefore the gradient is 1,
            # and so the variable will be init_val - grad * lr == 5 - 1 * 2 == 3
            self.assertAllClose([3.], self.evaluate(var))

            # Test DynamicLossScale
            var = variables.Variable([5.0])
            opt = gradient_descent.SGD(learning_rate)
            loss_scale = tf_loss_scale_module.DynamicLossScale(
                initial_loss_scale=4, increment_period=1, multiplier=2)
            loss_scale._current_loss_scale.assign(2)
            opt = loss_scale_optimizer.LossScaleOptimizerV1(opt, loss_scale)
            self.assertEqual(opt.initial_scale, 4)
            self.assertEqual(opt.dynamic_growth_steps, 1)
            self.evaluate(variables.global_variables_initializer())
            # The current loss scale is not copied, so the loss scale is
            # reinitialized to 4.
            self.assertEqual(self.evaluate(opt.loss_scale), 4)
            for s in strategy.experimental_local_results(opt.dynamic_counter):
                self.assertEqual(self.evaluate(s), 0)

            run_fn = self._run_fn_with_grad_check(
                strategy, var, opt, 4 / strategy.num_replicas_in_sync)
            run_op = strategy.experimental_run(run_fn)
            self.evaluate(variables.global_variables_initializer())
            self._run_if_in_graph_mode(run_op)
            self.assertAllClose([3.], self.evaluate(var))

    def testPassingV1LossScaleErrors(self):
        opt = gradient_descent.SGD()
        loss_scale = tf_loss_scale_module.DynamicLossScale(multiplier=4)
        with self.assertRaisesRegex(
                ValueError, 'When passing a DynamicLossScale to "loss_scale", '
                'DynamicLossScale.multiplier must be 2. Got: '
                'DynamicLossScale'):
            loss_scale_optimizer.LossScaleOptimizerV1(opt, loss_scale)

        class MyLossScale(tf_loss_scale_module.LossScale):
            def __call__(self):
                return 1.

            def update(self, grads):
                return None, True

            def get_config(self):
                return {}

        with self.assertRaisesRegex(
                TypeError,
                'Passing a LossScale that is not a FixedLossScale or a '
                'DynamicLossScale is no longer supported. Got:'):
            loss_scale_optimizer.LossScaleOptimizerV1(opt, MyLossScale())

    @parameterized.named_parameters(
        {
            'testcase_name': 'SaveAndRestoreBase',
            'strategy_fn': default_strategy_fn,
            'save_with_ls': True,
            'restore_with_ls': True,
        }, {
            'testcase_name': 'SaveAndRestoreDistribute',
            'strategy_fn': create_mirrored_strategy,
            'save_with_ls': True,
            'restore_with_ls': True,
        }, {
            'testcase_name': 'SaveBase',
            'strategy_fn': default_strategy_fn,
            'save_with_ls': True,
            'restore_with_ls': False,
        }, {
            'testcase_name': 'SaveDistribute',
            'strategy_fn': create_mirrored_strategy,
            'save_with_ls': True,
            'restore_with_ls': False,
        }, {
            'testcase_name': 'RestoreBase',
            'strategy_fn': default_strategy_fn,
            'save_with_ls': False,
            'restore_with_ls': True,
        }, {
            'testcase_name': 'RestoreDistribute',
            'strategy_fn': create_mirrored_strategy,
            'save_with_ls': False,
            'restore_with_ls': True,
        })
    def testCheckpoint(self, strategy_fn, save_with_ls, restore_with_ls):
        class MySGD(gradient_descent.SGD):
            """A custom optimizer that tracks an extra variable."""
            def __init__(self, *args, **kwargs):
                super(MySGD, self).__init__(*args, **kwargs)
                self.my_var = variables.Variable(0.)
                self._track_trackable(self.my_var, 'my_var')

        strategy = strategy_fn()
        replicas = strategy.num_replicas_in_sync
        if (isinstance(strategy, mirrored_strategy.MirroredStrategy)
                and not context.executing_eagerly()):
            # TODO(b/121381184): Enable running the test in this case.
            return

        with self.test_session(), strategy.scope():
            # Build and run a simple model.
            var = variables.Variable([2.0])
            opt = inner_opt = MySGD(1., momentum=1.)
            if save_with_ls:
                opt = loss_scale_optimizer.LossScaleOptimizer(
                    opt, initial_scale=1., dynamic_growth_steps=2.)
            run_fn = lambda: opt.minimize(lambda: var / replicas + 1.,
                                          var_list=[var])
            opt_op = strategy.experimental_run(run_fn)
            self.evaluate(variables.global_variables_initializer())
            self.evaluate(strategy.experimental_local_results(opt_op))

            # Assert values.
            self.assertEqual(self.evaluate(var), 1.)
            if save_with_ls:
                self.assertEqual(self.evaluate(opt.loss_scale), 1.)
                self.assertEqual(self.evaluate(opt.dynamic_counter), 1)
            slot_var = opt.get_slot(var, 'momentum')
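            # With lr=1 and momentum=1, the velocity slot is -lr * grad == -1
            # after one step, which moves the variable from 2 to 1.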
            self.assertEqual(self.evaluate(slot_var).item(), -1)
            self.assertEqual(self.evaluate(opt.iterations), 1)

            # Set an optimizer variable to check that arbitrary optimizer
            # attributes can be saved/restored.
            self.evaluate(inner_opt.my_var.assign(1.))

            # Save a checkpoint.
            checkpoint = trackable_utils.Checkpoint(optimizer=opt, var=var)
            prefix = os.path.join(self.get_temp_dir(), 'ckpt')
            save_path = checkpoint.save(prefix)

            # Create new model
            var = variables.Variable([2.0])
            opt = inner_opt = MySGD(1., momentum=1.)
            if restore_with_ls:
                opt = loss_scale_optimizer.LossScaleOptimizer(
                    opt, initial_scale=1., dynamic_growth_steps=2.)

            # Restore new model.
            checkpoint = trackable_utils.Checkpoint(optimizer=opt, var=var)
            status = checkpoint.restore(save_path)
            if save_with_ls:
                status.assert_existing_objects_matched()
            else:
                status.assert_nontrivial_match()

            # Assert restored values. We can only assert in eager mode since the
            # variables are uninitialized in graph mode
            if context.executing_eagerly():
                self.assertEqual(self.evaluate(var), 1.)
                if save_with_ls and restore_with_ls:
                    self.assertEqual(self.evaluate(opt.loss_scale), 1.)
                    self.assertEqual(self.evaluate(opt.dynamic_counter), 1)
                elif restore_with_ls:
                    self.assertEqual(self.evaluate(opt.loss_scale), 1.)
                    self.assertEqual(self.evaluate(opt.dynamic_counter), 0)
                self.assertEqual(self.evaluate(opt.iterations), 1)

            # Run the model again.
            run_fn = lambda: opt.minimize(lambda: var / replicas + 1.,
                                          var_list=[var])
            opt_op = strategy.experimental_run(run_fn)

            # Assert new values.
            self.evaluate(variables.global_variables_initializer())
            status.run_restore_ops()
            self.evaluate(strategy.experimental_local_results(opt_op))
            self.assertEqual(self.evaluate(var), -1)
            slot_var = opt.get_slot(var, 'momentum')
            self.assertEqual(self.evaluate(slot_var).item(), -2)
            self.assertEqual(self.evaluate(opt.iterations), 2)
            self.assertEqual(self.evaluate(inner_opt.my_var), 1)

            # Restore model again to test restoring after slots are created
            status = checkpoint.restore(save_path)
            if save_with_ls and restore_with_ls:
                status.assert_consumed()
            elif save_with_ls:
                status.assert_existing_objects_matched()
            elif restore_with_ls:
                status.assert_nontrivial_match()
            status.run_restore_ops()
            self.assertEqual(self.evaluate(var), 1)
            self.assertEqual(self.evaluate(slot_var).item(), -1)

    @combinations.generate(
        combinations.combine(get_config=['v1', 'v2', 'tf2_3'],
                             from_config=['v1', 'v2']))
    def testGetConfigFixed(self, get_config, from_config):
        # Get a config from LossScaleOptimizerV1, LossScaleOptimizer, or the
        # LossScaleOptimizer from TF 2.3. Then restore the config into a
        # LossScaleOptimizerV1 or LossScaleOptimizer
        opt = gradient_descent.SGD(2., momentum=0.5)
        if get_config == 'v1':
            opt = loss_scale_optimizer.LossScaleOptimizerV1(opt, 2)
            config = opt.get_config()
        elif get_config == 'v2':
            opt = loss_scale_optimizer.LossScaleOptimizer(opt,
                                                          dynamic=False,
                                                          initial_scale=2)
            config = opt.get_config()
        else:
            self.assertEqual(get_config, 'tf2_3')
            config = {
                'optimizer': {
                    'class_name': 'SGD',
                    'config': {
                        'learning_rate': 2.0,
                        'momentum': 0.5,
                        'decay': 0.0,
                        'nesterov': False,
                        'name': 'SGD',
                    }
                },
                'loss_scale': {
                    'class_name': 'FixedLossScale',
                    'config': {
                        'loss_scale_value': 2.0
                    }
                },
            }
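        # from_config still accepts this legacy 'loss_scale' key and maps it
        # onto the newer initial_scale/dynamic arguments, as asserted below.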

        if from_config == 'v1':
            opt = loss_scale_optimizer.LossScaleOptimizerV1.from_config(config)
        else:
            self.assertEqual(from_config, 'v2')
            opt = loss_scale_optimizer.LossScaleOptimizer.from_config(config)

        # Force hyperparameters to be created
        opt.lr  # pylint: disable=pointless-statement
        self.evaluate(variables.global_variables_initializer())

        # Test attributes on the optimizer
        self.assertEqual(self.evaluate(opt.lr), 2.)
        self.assertEqual(self.evaluate(opt.inner_optimizer.lr), 2.)
        self.assertEqual(self.evaluate(opt.momentum), 0.5)
        self.assertEqual(self.evaluate(opt.loss_scale), 2.)
        self.assertEqual(opt.initial_scale, 2.)
        self.assertIsNone(opt.dynamic_growth_steps)
        self.assertIsNone(opt.dynamic_counter)
        self.assertFalse(opt.dynamic)

        # Ensure the optimizer can be used
        var = variables.Variable([5.0])
        run_op = self._run_fn_with_grad_check(
            distribution_strategy_context.get_strategy(), var, opt, 2)()
        self.evaluate(variables.global_variables_initializer())
        self._run_if_in_graph_mode(run_op)
        self.assertEqual(self.evaluate(var), [3.])

    @combinations.generate(
        combinations.combine(get_config=['v1', 'v2', 'tf2_3'],
                             from_config=['v1', 'v2']))
    def testGetConfigDynamic(self, get_config, from_config):
        # Get a config from LossScaleOptimizerV1, LossScaleOptimizer, or the
        # LossScaleOptimizer from TF 2.3. Then restore the config into a
        # LossScaleOptimizerV1 or LossScaleOptimizer
        opt = gradient_descent.SGD(2., momentum=0.5)
        if get_config == 'v1':
            loss_scale = tf_loss_scale_module.DynamicLossScale(
                initial_loss_scale=2, increment_period=3)
            opt = loss_scale_optimizer.LossScaleOptimizerV1(opt, loss_scale)
            config = opt.get_config()
        elif get_config == 'v2':
            opt = loss_scale_optimizer.LossScaleOptimizer(
                opt, initial_scale=2, dynamic_growth_steps=3)
            config = opt.get_config()
        else:
            self.assertEqual(get_config, 'tf2_3')
            config = {
                'optimizer': {
                    'class_name': 'SGD',
                    'config': {
                        'learning_rate': 2.0,
                        'momentum': 0.5,
                        'decay': 0.0,
                        'nesterov': False,
                        'name': 'SGD',
                    }
                },
                'loss_scale': {
                    'class_name': 'DynamicLossScale',
                    'config': {
                        'initial_loss_scale': 2.0,
                        'increment_period': 3,
                        'multiplier': 2.0,
                    }
                },
            }

        if from_config == 'v1':
            opt = loss_scale_optimizer.LossScaleOptimizerV1.from_config(config)
        else:
            self.assertEqual(from_config, 'v2')
            opt = loss_scale_optimizer.LossScaleOptimizer.from_config(config)

        # Force hyperparameters to be created
        opt.lr  # pylint: disable=pointless-statement
        self.evaluate(variables.global_variables_initializer())

        # Test attributes on the optimizer
        self.assertEqual(self.evaluate(opt.lr), 2.)
        self.assertEqual(self.evaluate(opt.inner_optimizer.lr), 2.)
        self.assertEqual(self.evaluate(opt.momentum), 0.5)
        self.assertEqual(self.evaluate(opt.loss_scale), 2.)
        self.assertEqual(opt.initial_scale, 2.)
        self.assertEqual(opt.dynamic_growth_steps, 3.)
        self.assertTrue(opt.dynamic)

        # Ensure the optimizer can be used
        var = variables.Variable([5.0])
        run_op = self._run_fn_with_grad_check(
            distribution_strategy_context.get_strategy(), var, opt, 2)()
        self.evaluate(variables.global_variables_initializer())
        self._run_if_in_graph_mode(run_op)
        self.assertEqual(self.evaluate(var), [3.])
        self.assertEqual(self.evaluate(opt.dynamic_counter), 1)

    def test_from_config_with_invalid_multiplier(self):
        config = {
            'optimizer': {
                'class_name': 'SGD',
                'config': {
                    'learning_rate': 2.0,
                    'momentum': 0.5,
                    'decay': 0.0,
                    'nesterov': False,
                    'name': 'SGD',
                }
            },
            'loss_scale': {
                'class_name': 'DynamicLossScale',
                'config': {
                    'initial_loss_scale': 2.0,
                    'increment_period': 3,
                    'multiplier': 4.0,
                }
            },
        }

        expected_error = ('Cannot deserialize LossScaleOptimizer with a '
                          'DynamicLossScale whose multiplier is not 2. Got '
                          'DynamicLossScale: DynamicLossScale\\(')
        with self.assertRaisesRegex(ValueError, expected_error):
            loss_scale_optimizer.LossScaleOptimizer.from_config(config)
        with self.assertRaisesRegex(ValueError, expected_error):
            loss_scale_optimizer.LossScaleOptimizerV1.from_config(config)

    @parameterized.named_parameters(
        {
            'testcase_name': 'V2',
            'use_v1': False,
        },
        {
            'testcase_name': 'V1',
            'use_v1': True,
        },
    )
    def testSerializationWithBuiltInOptimizer(self, use_v1):
        opt = gradient_descent.SGD(2., momentum=0.5)
        if use_v1:
            loss_scale = tf_loss_scale_module.DynamicLossScale(
                initial_loss_scale=2., increment_period=3.)
            opt = loss_scale_optimizer.LossScaleOptimizerV1(opt, loss_scale)
        else:
            opt = loss_scale_optimizer.LossScaleOptimizer(
                opt, initial_scale=2., dynamic_growth_steps=3.)
        config = optimizers.serialize(opt)
        opt = optimizers.deserialize(config)
        # Force hyperparameters to be created
        opt.lr  # pylint: disable=pointless-statement
        self.evaluate(variables.global_variables_initializer())

        self.assertEqual(self.evaluate(opt.lr), 2.)
        self.assertEqual(self.evaluate(opt.inner_optimizer.momentum), 0.5)
        self.assertEqual(self.evaluate(opt.loss_scale), 2.)
        self.assertEqual(opt.dynamic_growth_steps, 3.)
        self.assertTrue(opt.dynamic)
        # Deserializing a LossScaleOptimizer always results in a V2
        # LossScaleOptimizer, even if serialized with a LossScaleOptimizerV1.
        self.assertEqual(type(opt), loss_scale_optimizer.LossScaleOptimizer)

        # Ensure the optimizer can be used
        var = variables.Variable([5.0])
        run_op = self._run_fn_with_grad_check(
            distribution_strategy_context.get_strategy(), var, opt, 2)()
        self.evaluate(variables.global_variables_initializer())
        self._run_if_in_graph_mode(run_op)
        self.assertEqual(self.evaluate(var), [3.])
        self.assertEqual(self.evaluate(opt.dynamic_counter), 1)

    def testSerializationWithCustomOptimizer(self):
        class MySGD(gradient_descent.SGD):
            def __init__(self, *args, **kwargs):
                super(MySGD, self).__init__(*args, **kwargs)
                self.my_attribute = 123

        opt = MySGD(2., momentum=0.5)
        opt = loss_scale_optimizer.LossScaleOptimizer(opt,
                                                      initial_scale=2.,
                                                      dynamic_growth_steps=3.)
        config = optimizers.serialize(opt)
        custom_objects = {'MySGD': MySGD}
        opt = optimizers.deserialize(config, custom_objects=custom_objects)
        # Force hyperparameters to be created
        opt.lr  # pylint: disable=pointless-statement
        self.evaluate(variables.global_variables_initializer())

        self.assertEqual(self.evaluate(opt.lr), 2.)
        self.assertEqual(self.evaluate(opt.inner_optimizer.momentum), 0.5)
        self.assertEqual(self.evaluate(opt.loss_scale), 2.)
        self.assertEqual(opt.dynamic_growth_steps, 3.)
        self.assertEqual(opt.inner_optimizer.my_attribute, 123)

    def testUnsupportedStrategy(self):
        strategy = central_storage_strategy.CentralStorageStrategy()
        expected_error = (
            'Loss scaling is not supported with the tf.distribute.Strategy: '
            'CentralStorageStrategy. Try using a different Strategy, e.g. a '
            'MirroredStrategy')
        with strategy.scope(), self.assertRaisesRegex(ValueError,
                                                      expected_error):
            loss_scale_optimizer.LossScaleOptimizer(gradient_descent.SGD())
        opt = loss_scale_optimizer.LossScaleOptimizer(gradient_descent.SGD())
        with strategy.scope():
            var = variables.Variable(1.0)
            loss = lambda: var * 2.0
            run_fn = lambda: opt.minimize(loss, [var])
            with self.assertRaisesRegex(ValueError, expected_error):
                strategy.experimental_run(run_fn)

    def testInvalidArgsWithFixedLossScale(self):
        opt = gradient_descent.SGD()
        with self.assertRaisesRegex(
                ValueError,
                '"initial_scale" must be specified if "dynamic" is False'):
            loss_scale_optimizer.LossScaleOptimizer(opt, dynamic=False)
        with self.assertRaisesRegex(
                ValueError,
                '"dynamic_growth_steps" must be None if "dynamic" is '
                'False, but got: 2'):
            loss_scale_optimizer.LossScaleOptimizer(opt,
                                                    dynamic=False,
                                                    initial_scale=1,
                                                    dynamic_growth_steps=2)

    def testDynamicMustBeBool(self):
        opt = gradient_descent.SGD()
        with self.assertRaisesRegex(
                TypeError,
                '"dynamic" argument to LossScaleOptimizer.__init__ must be '
                "a bool, but got: 'dynamic'"):
            loss_scale_optimizer.LossScaleOptimizer(opt, 'dynamic')
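
The tests above pin down the dynamic loss-scaling contract: steps with finite gradients apply the update and double the scale, while steps with non-finite gradients skip the update and halve it. As a minimal standalone sketch of that contract (assuming TensorFlow 2.4+, where the public tf.keras.mixed_precision.LossScaleOptimizer is available; this is illustrative, not part of the test file):

import tensorflow as tf

var = tf.Variable([1.0])
opt = tf.keras.mixed_precision.LossScaleOptimizer(
    tf.keras.optimizers.SGD(1.0), initial_scale=2, dynamic_growth_steps=1)

# Finite gradients: grad is 2, so var becomes 1 - 2 == -1 and the scale
# doubles from 2 to 4.
opt.minimize(lambda: var * 2.0, var_list=[var])
print(var.numpy(), opt.loss_scale.numpy())  # [-1.] 4.0

# Non-finite gradients: the update is skipped and the scale halves to 2.
opt.minimize(lambda: var * float('nan'), var_list=[var])
print(var.numpy(), opt.loss_scale.numpy())  # [-1.] 2.0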
Example #10
class DenseFeaturesTest(keras_parameterized.TestCase):
    @combinations.generate(combinations.combine(mode=['graph', 'eager']))
    def test_retrieving_input(self):
        features = {'a': [0.]}
        dense_features = df.DenseFeatures(fc.numeric_column('a'))
        inputs = self.evaluate(dense_features(features))
        self.assertAllClose([[0.]], inputs)

    def test_reuses_variables(self):
        with context.eager_mode():
            sparse_input = sparse_tensor.SparseTensor(indices=((0, 0), (1, 0),
                                                               (2, 0)),
                                                      values=(0, 1, 2),
                                                      dense_shape=(3, 3))

            # Create feature columns (categorical and embedding).
            categorical_column = fc.categorical_column_with_identity(
                key='a', num_buckets=3)
            embedding_dimension = 2

            def _embedding_column_initializer(shape,
                                              dtype,
                                              partition_info=None):
                del shape  # unused
                del dtype  # unused
                del partition_info  # unused
                embedding_values = (
                    (1, 0),  # id 0
                    (0, 1),  # id 1
                    (1, 1))  # id 2
                return embedding_values

            embedding_column = fc.embedding_column(
                categorical_column,
                dimension=embedding_dimension,
                initializer=_embedding_column_initializer)
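            # The fixed initializer above makes the lookup deterministic, so
            # the embedding output can be asserted exactly below.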

            dense_features = df.DenseFeatures([embedding_column])
            features = {'a': sparse_input}

            inputs = dense_features(features)
            variables = dense_features.variables

            # Sanity check: test that the inputs are correct.
            self.assertAllEqual([[1, 0], [0, 1], [1, 1]], inputs)

            # Check that only one variable was created.
            self.assertEqual(1, len(variables))

            # Check that invoking dense_features on the same features does not create
            # additional variables
            _ = dense_features(features)
            self.assertEqual(1, len(variables))
            self.assertIs(variables[0], dense_features.variables[0])

    def test_feature_column_dense_features_gradient(self):
        with context.eager_mode():
            sparse_input = sparse_tensor.SparseTensor(indices=((0, 0), (1, 0),
                                                               (2, 0)),
                                                      values=(0, 1, 2),
                                                      dense_shape=(3, 3))

            # Create feature columns (categorical and embedding).
            categorical_column = fc.categorical_column_with_identity(
                key='a', num_buckets=3)
            embedding_dimension = 2

            def _embedding_column_initializer(shape,
                                              dtype,
                                              partition_info=None):
                del shape  # unused
                del dtype  # unused
                del partition_info  # unused
                embedding_values = (
                    (1, 0),  # id 0
                    (0, 1),  # id 1
                    (1, 1))  # id 2
                return embedding_values

            embedding_column = fc.embedding_column(
                categorical_column,
                dimension=embedding_dimension,
                initializer=_embedding_column_initializer)

            dense_features = df.DenseFeatures([embedding_column])
            features = {'a': sparse_input}

            def scale_matrix():
                matrix = dense_features(features)
                return 2 * matrix

            # Sanity check: Verify that scale_matrix returns the correct output.
            self.assertAllEqual([[2, 0], [0, 2], [2, 2]], scale_matrix())

            # Check that the returned gradient is correct.
            grad_function = backprop.implicit_grad(scale_matrix)
            grads_and_vars = grad_function()
            indexed_slice = grads_and_vars[0][0]
            gradient = grads_and_vars[0][0].values
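            # d(2 * embedding)/d(embedding) is 2 everywhere, and each of ids
            # 0..2 is looked up exactly once, giving one row of 2s per id.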

            self.assertAllEqual([0, 1, 2], indexed_slice.indices)
            self.assertAllEqual([[2, 2], [2, 2], [2, 2]], gradient)

    def test_dense_feature_with_training_arg(self):
        price1 = fc.numeric_column('price1', shape=2)
        price2 = fc.numeric_column('price2')

        # Monkey patch the second numeric column to simulate a column that has
        # different behavior by mode.
        def training_aware_get_dense_tensor(transformation_cache,
                                            state_manager,
                                            training=None):
            return transformation_cache.get(price2,
                                            state_manager,
                                            training=training)

        def training_aware_transform_feature(transformation_cache,
                                             state_manager,
                                             training=None):
            input_tensor = transformation_cache.get(price2.key,
                                                    state_manager,
                                                    training=training)
            if training:
                return input_tensor * 10.0
            else:
                return input_tensor * 20.0

        price2.get_dense_tensor = training_aware_get_dense_tensor
        price2.transform_feature = training_aware_transform_feature
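        # price2 is scaled 10x when training=True and 20x otherwise, which is
        # what the two assertions below verify.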
        with ops.Graph().as_default():
            features = {'price1': [[1., 2.], [5., 6.]], 'price2': [[3.], [4.]]}
            train_mode = df.DenseFeatures([price1, price2])(features,
                                                            training=True)
            predict_mode = df.DenseFeatures([price1, price2])(features,
                                                              training=False)

            self.evaluate(variables_lib.global_variables_initializer())
            self.evaluate(lookup_ops.tables_initializer())

            self.assertAllClose([[1., 2., 30.], [5., 6., 40.]],
                                self.evaluate(train_mode))
            self.assertAllClose([[1., 2., 60.], [5., 6., 80.]],
                                self.evaluate(predict_mode))

    def test_raises_if_empty_feature_columns(self):
        with self.assertRaisesRegex(ValueError,
                                    'feature_columns must not be empty'):
            df.DenseFeatures(feature_columns=[])(features={})

    def test_should_be_dense_column(self):
        with self.assertRaisesRegex(ValueError, 'must be a .*DenseColumn'):
            df.DenseFeatures(feature_columns=[
                fc.categorical_column_with_hash_bucket('wire_cast', 4)
            ])(features={
                'a': [[0]]
            })

    def test_does_not_support_dict_columns(self):
        with self.assertRaisesRegex(
                ValueError,
                'Expected feature_columns to be iterable, found dict.'):
            df.DenseFeatures(feature_columns={'a': fc.numeric_column('a')})(
                features={
                    'a': [[0]]
                })

    def test_bare_column(self):
        with ops.Graph().as_default():
            features = {'a': [0.]}
            net = df.DenseFeatures(fc.numeric_column('a'))(features)

            self.evaluate(variables_lib.global_variables_initializer())
            self.evaluate(lookup_ops.tables_initializer())

            self.assertAllClose([[0.]], self.evaluate(net))

    def test_column_generator(self):
        with ops.Graph().as_default():
            features = {'a': [0.], 'b': [1.]}
            columns = (fc.numeric_column(key) for key in features)
            net = df.DenseFeatures(columns)(features)

            self.evaluate(variables_lib.global_variables_initializer())
            self.evaluate(lookup_ops.tables_initializer())

            self.assertAllClose([[0., 1.]], self.evaluate(net))

    def test_raises_if_duplicate_name(self):
        with self.assertRaisesRegex(
                ValueError, 'Duplicate feature column name found for columns'):
            df.DenseFeatures(feature_columns=[
                fc.numeric_column('a'),
                fc.numeric_column('a')
            ])(features={
                'a': [[0]]
            })

    def test_one_column(self):
        price = fc.numeric_column('price')
        with ops.Graph().as_default():
            features = {'price': [[1.], [5.]]}
            net = df.DenseFeatures([price])(features)

            self.evaluate(variables_lib.global_variables_initializer())
            self.evaluate(lookup_ops.tables_initializer())

            self.assertAllClose([[1.], [5.]], self.evaluate(net))

    def test_multi_dimension(self):
        price = fc.numeric_column('price', shape=2)
        with ops.Graph().as_default():
            features = {'price': [[1., 2.], [5., 6.]]}
            net = df.DenseFeatures([price])(features)

            self.evaluate(variables_lib.global_variables_initializer())
            self.evaluate(lookup_ops.tables_initializer())

            self.assertAllClose([[1., 2.], [5., 6.]], self.evaluate(net))

    def test_compute_output_shape(self):
        price1 = fc.numeric_column('price1', shape=2)
        price2 = fc.numeric_column('price2', shape=4)
        with ops.Graph().as_default():
            features = {
                'price1': [[1., 2.], [5., 6.]],
                'price2': [[3., 4., 5., 6.], [7., 8., 9., 10.]]
            }
            dense_features = df.DenseFeatures([price1, price2])
            self.assertEqual((None, 6),
                             dense_features.compute_output_shape((None, )))
            net = dense_features(features)

            self.evaluate(variables_lib.global_variables_initializer())
            self.evaluate(lookup_ops.tables_initializer())

            self.assertAllClose(
                [[1., 2., 3., 4., 5., 6.], [5., 6., 7., 8., 9., 10.]],
                self.evaluate(net))

    def test_raises_if_shape_mismatch(self):
        price = fc.numeric_column('price', shape=2)
        with ops.Graph().as_default():
            features = {'price': [[1.], [5.]]}
            with self.assertRaisesRegex(
                    Exception,
                    r'Cannot reshape a tensor with 2 elements to shape \[2,2\]'
            ):
                df.DenseFeatures([price])(features)

    def test_reshaping(self):
        price = fc.numeric_column('price', shape=[1, 2])
        with ops.Graph().as_default():
            features = {'price': [[[1., 2.]], [[5., 6.]]]}
            net = df.DenseFeatures([price])(features)

            self.evaluate(variables_lib.global_variables_initializer())
            self.evaluate(lookup_ops.tables_initializer())

            self.assertAllClose([[1., 2.], [5., 6.]], self.evaluate(net))

    def test_multi_column(self):
        price1 = fc.numeric_column('price1', shape=2)
        price2 = fc.numeric_column('price2')
        with ops.Graph().as_default():
            features = {'price1': [[1., 2.], [5., 6.]], 'price2': [[3.], [4.]]}
            net = df.DenseFeatures([price1, price2])(features)

            self.evaluate(variables_lib.global_variables_initializer())
            self.evaluate(lookup_ops.tables_initializer())

            self.assertAllClose([[1., 2., 3.], [5., 6., 4.]],
                                self.evaluate(net))

    def test_cols_to_output_tensors(self):
        price1 = fc.numeric_column('price1', shape=2)
        price2 = fc.numeric_column('price2')
        with ops.Graph().as_default():
            cols_dict = {}
            features = {'price1': [[1., 2.], [5., 6.]], 'price2': [[3.], [4.]]}
            dense_features = df.DenseFeatures([price1, price2])
            net = dense_features(features, cols_dict)

            self.evaluate(variables_lib.global_variables_initializer())
            self.evaluate(lookup_ops.tables_initializer())

            self.assertAllClose([[1., 2.], [5., 6.]],
                                self.evaluate(cols_dict[price1]))
            self.assertAllClose([[3.], [4.]], self.evaluate(cols_dict[price2]))
            self.assertAllClose([[1., 2., 3.], [5., 6., 4.]],
                                self.evaluate(net))

    def test_column_order(self):
        price_a = fc.numeric_column('price_a')
        price_b = fc.numeric_column('price_b')
        with ops.Graph().as_default():
            features = {
                'price_a': [[1.]],
                'price_b': [[3.]],
            }
            net1 = df.DenseFeatures([price_a, price_b])(features)
            net2 = df.DenseFeatures([price_b, price_a])(features)

            self.evaluate(variables_lib.global_variables_initializer())
            self.evaluate(lookup_ops.tables_initializer())

            self.assertAllClose([[1., 3.]], self.evaluate(net1))
            self.assertAllClose([[1., 3.]], self.evaluate(net2))

    def test_fails_for_categorical_column(self):
        animal = fc.categorical_column_with_identity('animal', num_buckets=4)
        with ops.Graph().as_default():
            features = {
                'animal':
                sparse_tensor.SparseTensor(indices=[[0, 0], [0, 1]],
                                           values=[1, 2],
                                           dense_shape=[1, 2])
            }
            with self.assertRaisesRegex(Exception, 'must be a .*DenseColumn'):
                df.DenseFeatures([animal])(features)

    def test_static_batch_size_mismatch(self):
        price1 = fc.numeric_column('price1')
        price2 = fc.numeric_column('price2')
        with ops.Graph().as_default():
            features = {
                'price1': [[1.], [5.], [7.]],  # batchsize = 3
                'price2': [[3.], [4.]]  # batchsize = 2
            }
            with self.assertRaisesRegex(
                    ValueError,
                    r'Batch size \(first dimension\) of each feature must be same.'
            ):  # pylint: disable=anomalous-backslash-in-string
                df.DenseFeatures([price1, price2])(features)

    def test_subset_of_static_batch_size_mismatch(self):
        price1 = fc.numeric_column('price1')
        price2 = fc.numeric_column('price2')
        price3 = fc.numeric_column('price3')
        with ops.Graph().as_default():
            features = {
                'price1':
                array_ops.placeholder(dtype=dtypes.int64),  # batchsize = 3
                'price2': [[3.], [4.]],  # batchsize = 2
                'price3': [[3.], [4.], [5.]]  # batchsize = 3
            }
            with self.assertRaisesRegex(
                    ValueError,
                    r'Batch size \(first dimension\) of each feature must be same.'
            ):  # pylint: disable=anomalous-backslash-in-string
                df.DenseFeatures([price1, price2, price3])(features)

    def test_runtime_batch_size_mismatch(self):
        price1 = fc.numeric_column('price1')
        price2 = fc.numeric_column('price2')
        with ops.Graph().as_default():
            features = {
                'price1':
                array_ops.placeholder(dtype=dtypes.int64),  # batchsize = 3
                'price2': [[3.], [4.]]  # batchsize = 2
            }
            net = df.DenseFeatures([price1, price2])(features)
            with _initialized_session() as sess:
                with self.assertRaisesRegex(
                        errors.OpError, 'Dimensions of inputs should match'):
                    sess.run(
                        net,
                        feed_dict={features['price1']: [[1.], [5.], [7.]]})

    def test_runtime_batch_size_matches(self):
        price1 = fc.numeric_column('price1')
        price2 = fc.numeric_column('price2')
        with ops.Graph().as_default():
            features = {
                'price1':
                array_ops.placeholder(dtype=dtypes.int64),  # batchsize = 2
                'price2':
                array_ops.placeholder(dtype=dtypes.int64),  # batchsize = 2
            }
            net = df.DenseFeatures([price1, price2])(features)
            with _initialized_session() as sess:
                sess.run(net,
                         feed_dict={
                             features['price1']: [[1.], [5.]],
                             features['price2']: [[1.], [5.]],
                         })

    def test_multiple_layers_with_same_embedding_column(self):
        some_sparse_column = fc.categorical_column_with_hash_bucket(
            'sparse_feature', hash_bucket_size=5)
        some_embedding_column = fc.embedding_column(some_sparse_column,
                                                    dimension=10)

        with ops.Graph().as_default():
            features = {
                'sparse_feature': [['a'], ['x']],
            }
            all_cols = [some_embedding_column]
            df.DenseFeatures(all_cols)(features)
            df.DenseFeatures(all_cols)(features)
            # Make sure that 2 variables get created in this case.
            self.assertEqual(
                2, len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)))
            expected_var_names = [
                'dense_features/sparse_feature_embedding/embedding_weights:0',
                'dense_features_1/sparse_feature_embedding/embedding_weights:0'
            ]
            self.assertItemsEqual(expected_var_names, [
                v.name
                for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
            ])

    def test_multiple_layers_with_same_shared_embedding_column(self):
        categorical_column_a = fc.categorical_column_with_identity(
            key='aaa', num_buckets=3)
        categorical_column_b = fc.categorical_column_with_identity(
            key='bbb', num_buckets=3)
        embedding_dimension = 2

        # feature_column.shared_embeddings is not supported in eager.
        with ops.Graph().as_default():
            embedding_column_b, embedding_column_a = fc.shared_embedding_columns_v2(
                [categorical_column_b, categorical_column_a],
                dimension=embedding_dimension)
            features = {
                'aaa':
                sparse_tensor.SparseTensor(indices=((0, 0), (1, 0), (1, 1)),
                                           values=(0, 1, 0),
                                           dense_shape=(2, 2)),
                'bbb':
                sparse_tensor.SparseTensor(indices=((0, 0), (1, 0), (1, 1)),
                                           values=(1, 2, 1),
                                           dense_shape=(2, 2)),
            }
            all_cols = [embedding_column_a, embedding_column_b]
            df.DenseFeatures(all_cols)(features)
            df.DenseFeatures(all_cols)(features)
            # Make sure that only 1 variable gets created in this case.
            self.assertEqual(
                1, len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)))
            self.assertItemsEqual(['aaa_bbb_shared_embedding:0'], [
                v.name
                for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
            ])

    def test_multiple_layers_with_same_shared_embedding_column_diff_graphs(
            self):
        categorical_column_a = fc.categorical_column_with_identity(
            key='aaa', num_buckets=3)
        categorical_column_b = fc.categorical_column_with_identity(
            key='bbb', num_buckets=3)
        embedding_dimension = 2

        # feature_column.shared_embeddings is not supported in eager.
        with ops.Graph().as_default():
            embedding_column_b, embedding_column_a = fc.shared_embedding_columns_v2(
                [categorical_column_b, categorical_column_a],
                dimension=embedding_dimension)
            all_cols = [embedding_column_a, embedding_column_b]
            features = {
                'aaa':
                sparse_tensor.SparseTensor(indices=((0, 0), (1, 0), (1, 1)),
                                           values=(0, 1, 0),
                                           dense_shape=(2, 2)),
                'bbb':
                sparse_tensor.SparseTensor(indices=((0, 0), (1, 0), (1, 1)),
                                           values=(1, 2, 1),
                                           dense_shape=(2, 2)),
            }
            df.DenseFeatures(all_cols)(features)
            # Make sure that only 1 variable gets created in this case.
            self.assertEqual(
                1, len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)))

        with ops.Graph().as_default():
            features1 = {
                'aaa':
                sparse_tensor.SparseTensor(indices=((0, 0), (1, 0), (1, 1)),
                                           values=(0, 1, 0),
                                           dense_shape=(2, 2)),
                'bbb':
                sparse_tensor.SparseTensor(indices=((0, 0), (1, 0), (1, 1)),
                                           values=(1, 2, 1),
                                           dense_shape=(2, 2)),
            }

            df.DenseFeatures(all_cols)(features1)
            # Make sure that only 1 variable gets created in this case.
            self.assertEqual(
                1, len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)))
            self.assertItemsEqual(['aaa_bbb_shared_embedding:0'], [
                v.name
                for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
            ])

    def test_with_1d_sparse_tensor(self):
        embedding_values = (
            (1., 2., 3., 4., 5.),  # id 0
            (6., 7., 8., 9., 10.),  # id 1
            (11., 12., 13., 14., 15.)  # id 2
        )

        def _initializer(shape, dtype, partition_info=None):
            del shape, dtype, partition_info
            return embedding_values

        # price has 1 dimension in dense_features
        price = fc.numeric_column('price')

        # one_hot_body_style has 3 dims in dense_features.
        body_style = fc.categorical_column_with_vocabulary_list(
            'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
        one_hot_body_style = fc.indicator_column(body_style)

        # embedded_country has 5 dims in dense_features.
        country = fc.categorical_column_with_vocabulary_list(
            'country', vocabulary_list=['US', 'JP', 'CA'])
        embedded_country = fc.embedding_column(country,
                                               dimension=5,
                                               initializer=_initializer)

        with ops.Graph().as_default():
            # Provides 1-dim tensor and dense tensor.
            features = {
                'price':
                constant_op.constant([
                    11.,
                    12.,
                ]),
                'body-style':
                sparse_tensor.SparseTensor(indices=((0, ), (1, )),
                                           values=('sedan', 'hardtop'),
                                           dense_shape=(2, )),
                # This is dense tensor for the categorical_column.
                'country':
                constant_op.constant(['CA', 'US']),
            }
            self.assertEqual(1, features['price'].shape.ndims)
            self.assertEqual(1,
                             features['body-style'].dense_shape.get_shape()[0])
            self.assertEqual(1, features['country'].shape.ndims)

            net = df.DenseFeatures(
                [price, one_hot_body_style, embedded_country])(features)
            self.assertEqual(1 + 3 + 5, net.shape[1])
            with _initialized_session() as sess:

                # Each row is formed by concatenating `one_hot_body_style`,
                # `embedded_country`, and `price` in order.
                self.assertAllEqual(
                    [[0., 0., 1., 11., 12., 13., 14., 15., 11.],
                     [1., 0., 0., 1., 2., 3., 4., 5., 12.]], sess.run(net))

    def test_with_1d_unknown_shape_sparse_tensor(self):
        embedding_values = (
            (1., 2.),  # id 0
            (6., 7.),  # id 1
            (11., 12.)  # id 2
        )

        def _initializer(shape, dtype, partition_info=None):
            del shape, dtype, partition_info
            return embedding_values

        # price has 1 dimension in dense_features
        price = fc.numeric_column('price')

        # one_hot_body_style has 3 dims in dense_features.
        body_style = fc.categorical_column_with_vocabulary_list(
            'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
        one_hot_body_style = fc.indicator_column(body_style)

        # embedded_country has 2 dims in dense_features.
        country = fc.categorical_column_with_vocabulary_list(
            'country', vocabulary_list=['US', 'JP', 'CA'])
        embedded_country = fc.embedding_column(country,
                                               dimension=2,
                                               initializer=_initializer)

        # Provides 1-dim dense and sparse tensors.
        with ops.Graph().as_default():
            features = {
                'price': array_ops.placeholder(dtypes.float32),
                'body-style': array_ops.sparse_placeholder(dtypes.string),
                # This is a dense tensor for the categorical column.
                'country': array_ops.placeholder(dtypes.string),
            }
            self.assertIsNone(features['price'].shape.ndims)
            self.assertIsNone(features['body-style'].get_shape().ndims)
            self.assertIsNone(features['country'].shape.ndims)

            price_data = np.array([11., 12.])
            body_style_data = sparse_tensor.SparseTensorValue(
                indices=((0, ), (1, )),
                values=('sedan', 'hardtop'),
                dense_shape=(2, ))
            country_data = np.array([['US'], ['CA']])

            net = df.DenseFeatures(
                [price, one_hot_body_style, embedded_country])(features)
            self.assertEqual(1 + 3 + 2, net.shape[1])
            with _initialized_session() as sess:
                # Each row concatenates `one_hot_body_style`, `embedded_country`,
                # and `price`, in sorted order of the column names.
                self.assertAllEqual(
                    [[0., 0., 1., 1., 2., 11.], [1., 0., 0., 11., 12., 12.]],
                    sess.run(net,
                             feed_dict={
                                 features['price']: price_data,
                                 features['body-style']: body_style_data,
                                 features['country']: country_data
                             }))

    def test_with_rank_0_feature(self):
        # price has 1 dimension in dense_features
        price = fc.numeric_column('price')
        features = {
            'price': constant_op.constant(0),
        }
        self.assertEqual(0, features['price'].shape.ndims)

        # Static rank 0 should fail
        with self.assertRaisesRegex(ValueError,
                                    'Feature .* cannot have rank 0'):
            df.DenseFeatures([price])(features)

        with ops.Graph().as_default():
            # Dynamic rank 0 should fail
            features = {
                'price': array_ops.placeholder(dtypes.float32),
            }
            net = df.DenseFeatures([price])(features)
            self.assertEqual(1, net.shape[1])
            with _initialized_session() as sess:
                with self.assertRaisesOpError('Feature .* cannot have rank 0'):
                    sess.run(net, feed_dict={features['price']: np.array(1)})
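
Note: the tests above use an `_initialized_session` helper that is defined
earlier in the original file and not shown in this excerpt. A plausible
minimal sketch, consistent with its use here as a context manager that also
initializes the lookup tables behind the vocabulary-list columns (the exact
signature is an assumption):

from tensorflow.python.client import session as session_lib
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import variables as variables_lib


def _initialized_session(config=None):
    # Open a session and run both variable and table initializers; the
    # vocabulary-list columns create lookup tables that need initializing.
    sess = session_lib.Session(config=config)
    sess.run(variables_lib.global_variables_initializer())
    sess.run(lookup_ops.tables_initializer())
    return sess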
Example #11
0
class DropoutTest(test.TestCase, parameterized.TestCase):
    @combinations.generate(combinations.combine(mode=['graph', 'eager']))
    def testDropoutProperties(self):
        dp = core_layers.Dropout(0.5, name='dropout')
        self.assertEqual(dp.rate, 0.5)
        self.assertEqual(dp.noise_shape, None)
        dp.apply(array_ops.ones(()))
        self.assertEqual(dp.name, 'dropout')

    @combinations.generate(combinations.combine(mode=['graph', 'eager']))
    def testBooleanLearningPhase(self):
        dp = core_layers.Dropout(0.5)
        inputs = array_ops.ones((5, 3))
        dropped = dp.apply(inputs, training=True)
        if not context.executing_eagerly():
            self.evaluate(variables.global_variables_initializer())
        np_output = self.evaluate(dropped)
        self.assertAlmostEqual(0., np_output.min())
        dropped = dp.apply(inputs, training=False)
        np_output = self.evaluate(dropped)
        self.assertAllClose(np.ones((5, 3)), np_output)

    @test_util.run_deprecated_v1
    def testDynamicLearningPhase(self):
        with self.cached_session() as sess:
            dp = core_layers.Dropout(0.5, seed=1)
            inputs = array_ops.ones((5, 5))
            training = array_ops.placeholder(dtype='bool')
            dropped = dp.apply(inputs, training=training)
            self.evaluate(variables.global_variables_initializer())
            np_output = sess.run(dropped, feed_dict={training: True})
            self.assertAlmostEqual(0., np_output.min())
            np_output = sess.run(dropped, feed_dict={training: False})
            self.assertAllClose(np.ones((5, 5)), np_output)

    @combinations.generate(combinations.combine(mode=['graph', 'eager']))
    def testDynamicNoiseShape(self):
        inputs = array_ops.ones((5, 3, 2))
        noise_shape = [None, 1, None]
        dp = core_layers.Dropout(0.5, noise_shape=noise_shape, seed=1)
        dropped = dp.apply(inputs, training=True)
        self.evaluate(variables.global_variables_initializer())
        np_output = self.evaluate(dropped)
        self.assertAlmostEqual(0., np_output.min())
        self.assertAllClose(np_output[:, 0, :], np_output[:, 1, :])

    def testCustomNoiseShape(self):
        inputs = array_ops.ones((5, 3, 2))
        noise_shape = [5, 1, 2]
        dp = core_layers.Dropout(0.5, noise_shape=noise_shape, seed=1)
        dropped = dp.apply(inputs, training=True)
        self.evaluate(variables.global_variables_initializer())
        np_output = self.evaluate(dropped)
        self.assertAlmostEqual(0., np_output.min())
        self.assertAllClose(np_output[:, 0, :], np_output[:, 1, :])

    @test_util.run_deprecated_v1
    def testFunctionalDropout(self):
        with self.cached_session():
            inputs = array_ops.ones((5, 5))
            dropped = core_layers.dropout(inputs, 0.5, training=True, seed=1)
            self.evaluate(variables.global_variables_initializer())
            np_output = self.evaluate(dropped)
            self.assertAlmostEqual(0., np_output.min())
            dropped = core_layers.dropout(inputs, 0.5, training=False, seed=1)
            np_output = self.evaluate(dropped)
            self.assertAllClose(np.ones((5, 5)), np_output)

    @test_util.run_deprecated_v1
    def testDynamicRate(self):
        with self.cached_session() as sess:
            rate = array_ops.placeholder(dtype='float32', name='rate')
            dp = core_layers.Dropout(rate, name='dropout')
            inputs = array_ops.ones((5, 5))
            dropped = dp.apply(inputs, training=True)
            self.evaluate(variables.global_variables_initializer())
            np_output = sess.run(dropped, feed_dict={rate: 0.5})
            self.assertAlmostEqual(0., np_output.min())
            np_output = sess.run(dropped, feed_dict={rate: 0.0})
            self.assertAllClose(np.ones((5, 5)), np_output)
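
The noise_shape tests above rely on dropout-mask broadcasting. As a minimal
standalone sketch of the same behavior, using the public TF 2.x API rather
than the internal modules the tests import (an assumption for illustration):

import tensorflow as tf

x = tf.ones((5, 3, 2))
# With noise_shape=[5, 1, 2], one mask entry is drawn per (batch, feature)
# pair and broadcast along axis 1, so every slice along that axis keeps or
# drops the same positions.
dropped = tf.nn.dropout(x, rate=0.5, noise_shape=[5, 1, 2], seed=1)
tf.debugging.assert_near(dropped[:, 0, :], dropped[:, 1, :])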
Example #12
0
class OptimizerTest(test.TestCase, parameterized.TestCase):

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testBasic(self):
    for dtype in _DATA_TYPES:
      with testing_utils.use_gpu():
        var0 = variables.Variable([1.0, 2.0], dtype=dtype)
        var1 = variables.Variable([3.0, 4.0], dtype=dtype)
        loss = lambda: 5 * var0 + 3 * var1  # pylint: disable=cell-var-from-loop
        sgd = gradient_descent.SGD(3.0)

        self.evaluate(variables.global_variables_initializer())
        # Fetch params to validate initial values
        self.assertAllClose([1.0, 2.0], self.evaluate(var0))
        self.assertAllClose([3.0, 4.0], self.evaluate(var1))
        # Run 1 step of sgd through optimizer
        opt_op = sgd.minimize(loss, var_list=[var0, var1])
        self.evaluate(variables.global_variables_initializer())
        self.evaluate(opt_op)
        # Validate updated params
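        # var0 = [1., 2.] - 3.0 * [5, 5]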
        self.assertAllClose([-14., -13.], self.evaluate(var0))
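        # var1 = [3., 4.] - 3.0 * [3, 3]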
        self.assertAllClose([-6., -5.], self.evaluate(var1))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testAdaptiveLearningRate(self):
    for dtype in _DATA_TYPES:
      with self.test_session():
        var0 = variables.Variable([1.0, 2.0], dtype=dtype)
        var1 = variables.Variable([3.0, 4.0], dtype=dtype)

        def loss():
          return 5 * var0 + 3 * var1  # pylint: disable=cell-var-from-loop

        sgd = gradient_descent.SGD(1.0)

        self.evaluate(variables.global_variables_initializer())
        # Fetch params to validate initial values
        self.assertAllClose([1.0, 2.0], self.evaluate(var0))
        self.assertAllClose([3.0, 4.0], self.evaluate(var1))
        # Run 1 step of sgd through optimizer
        opt_op = sgd.minimize(loss, [var0, var1])
        self.evaluate(variables.global_variables_initializer())
        self.evaluate(opt_op)
        # Validate updated params
        # var0 = [1., 2.] - 1.0 * [5, 5]
        self.assertAllClose([-4., -3.], self.evaluate(var0))
        # var1 = [3., 4.] - 1.0 * [3, 3]
        self.assertAllClose([0., 1.], self.evaluate(var1))

        sgd.learning_rate = 0.5
        if context.executing_eagerly():
          sgd.minimize(loss, [var0, var1])
        else:
          self.evaluate(opt_op)
        # Validate updated params
        # var0 = [-4., -3.] - 0.5 * [5, 5]
        self.assertAllClose([-6.5, -5.5], self.evaluate(var0))
        # var1 = [0., 1.] - 0.5 * [3, 3]
        self.assertAllClose([-1.5, -0.5], self.evaluate(var1))

        sgd.learning_rate = learning_rate_schedule.InverseTimeDecay(
            0.5, decay_steps=1.0, decay_rate=0.5)
        if context.executing_eagerly():
          sgd.minimize(loss, [var0, var1])
        else:
          self.evaluate(opt_op)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testPrecomputedGradient(self):
    for dtype in _DATA_TYPES:
      with testing_utils.use_gpu():
        var0 = variables.Variable([1.0, 2.0], dtype=dtype)
        var1 = variables.Variable([3.0, 4.0], dtype=dtype)
        loss = lambda: 5 * var0 + 3 * var1  # pylint: disable=cell-var-from-loop
        grad_loss = constant_op.constant([42, -42], dtype=dtype)
        sgd = gradient_descent.SGD(3.0)

        self.evaluate(variables.global_variables_initializer())
        # Fetch params to validate initial values
        self.assertAllClose([1.0, 2.0], self.evaluate(var0))
        self.assertAllClose([3.0, 4.0], self.evaluate(var1))
        # Run 1 step of sgd through optimizer
        opt_op = sgd.minimize(loss, var_list=[var0, var1], grad_loss=grad_loss)
        self.evaluate(variables.global_variables_initializer())
        self.evaluate(opt_op)
        # Validate updated params
        self.assertAllClose([1.0 - 3 * 5 * 42.0, 2.0 - 3 * 5 * (-42.0)],
                            self.evaluate(var0))
        self.assertAllClose([3.0 - 3 * 3 * 42.0, 4.0 - 3 * 3 * (-42.0)],
                            self.evaluate(var1))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testNoGradients(self):
    for dtype in _DATA_TYPES:
      with testing_utils.use_gpu():
        var0 = variables.Variable([1.0, 2.0], dtype=dtype)
        var1 = variables.Variable([3.0, 4.0], dtype=dtype)
        loss = lambda: 5 * var0  # pylint: disable=cell-var-from-loop
        sgd_op = gradient_descent.SGD(3.0)
        with self.assertRaisesRegex(ValueError, 'No gradients'):
          # var1 has no gradient
          sgd_op.minimize(loss, var_list=[var1])

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testNoGradientsForAnyVariables_Minimize(self):
    for dtype in _DATA_TYPES:
      with testing_utils.use_gpu():
        var0 = variables.Variable([1.0, 2.0], dtype=dtype)
        var1 = variables.Variable([3.0, 4.0], dtype=dtype)
        loss = lambda: constant_op.constant(5.0)

        sgd_op = gradient_descent.SGD(3.0)
        with self.assertRaisesRegex(ValueError,
                                    'No gradients provided for any variable'):
          sgd_op.minimize(loss, var_list=[var0, var1])

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testNoGradientsForAnyVariables_ApplyGradients(self):
    for dtype in _DATA_TYPES:
      with testing_utils.use_gpu():
        var0 = variables.Variable([1.0, 2.0], dtype=dtype)
        var1 = variables.Variable([3.0, 4.0], dtype=dtype)
        sgd_op = gradient_descent.SGD(3.0)
        with self.assertRaisesRegex(ValueError,
                                    'No gradients provided for any variable'):
          sgd_op.apply_gradients([(None, var0), (None, var1)])

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testGradientsAsVariables(self):
    for i, dtype in enumerate(_DATA_TYPES):
      with testing_utils.use_gpu():
        var0 = variables.Variable([1.0, 2.0], dtype=dtype)
        var1 = variables.Variable([3.0, 4.0], dtype=dtype)
        loss = lambda: 5 * var0 + 3 * var1  # pylint: disable=cell-var-from-loop

        sgd = gradient_descent.SGD(3.0)
        grads_and_vars = sgd._compute_gradients(loss, [var0, var1])
        # Convert gradients to tf.Variables
        converted_grads = [
            variables.Variable(
                array_ops.zeros([2], dtype), name='c_%d_%d' % (i, j))
            for j, gv in enumerate(grads_and_vars)
        ]
        convert_ops = [
            state_ops.assign(converted_grads[j], gv[0])
            for j, gv in enumerate(grads_and_vars)
        ]

        # Run convert_ops to perform the gradient conversion.
        self.evaluate(variables.global_variables_initializer())
        self.evaluate(convert_ops)
        # Fetch params to validate initial values
        self.assertAllClose([1.0, 2.0], self.evaluate(var0))
        self.assertAllClose([3.0, 4.0], self.evaluate(var1))

        # Run 1 step of sgd through optimizer
        converted_grads_and_vars = list(zip(converted_grads, [var0, var1]))
        opt_op = sgd.apply_gradients(converted_grads_and_vars)
        self.evaluate(variables.global_variables_initializer())
        self.evaluate(convert_ops)
        self.evaluate(opt_op)

        # Validate updated params
        self.assertAllClose([-14., -13.], self.evaluate(var0))
        self.assertAllClose([-6., -5.], self.evaluate(var1))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testComputeGradientsWithTensors(self):
    with testing_utils.use_gpu():
      x = ops.convert_to_tensor_v2_with_dispatch(1.0)

      def f():
        return x * x

      sgd = gradient_descent.SGD(3.0)
      grads_and_vars = sgd._compute_gradients(f, [x])
      self.assertLen(grads_and_vars, 1)
      grad, x_as_var = grads_and_vars[0]
      self.assertIs(x, x_as_var)
      self.assertEqual(2.0, self.evaluate(grad))

      with self.assertRaises(NotImplementedError):
        sgd.apply_gradients(grads_and_vars)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testConstraint(self):
    constraint_01 = lambda x: clip_ops.clip_by_value(x, -0.1, 0.)
    constraint_0 = lambda x: clip_ops.clip_by_value(x, 0., 1.)
    with testing_utils.use_gpu():
      var0 = variables.Variable([1.0, 2.0],
                                constraint=constraint_01)
      var1 = variables.Variable([3.0, 4.0],
                                constraint=constraint_0)
      loss = lambda: 5 * var0 + 3 * var1
      sgd = gradient_descent.SGD(3.0)

      self.evaluate(variables.global_variables_initializer())
      # Fetch params to validate initial values
      self.assertAllClose([1.0, 2.0], self.evaluate(var0))
      self.assertAllClose([3.0, 4.0], self.evaluate(var1))
      # Run 1 step of sgd through optimizer
      opt_op = sgd.minimize(loss, var_list=[var0, var1])
      self.evaluate(variables.global_variables_initializer())
      self.evaluate(opt_op)
      # Validate updated params
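      # Without constraints the step would give [-14., -13.] and [-6., -5.];
      # constraint_01 then clips var0 into [-0.1, 0.] and constraint_0 clips
      # var1 into [0., 1.].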
      self.assertAllClose([-0.1, -0.1], self.evaluate(var0))
      self.assertAllClose([0., 0.], self.evaluate(var1))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testIterationWithoutMinimize(self):
    with testing_utils.use_gpu():
      sgd = gradient_descent.SGD(3.0)
      self.evaluate(sgd.iterations.initializer)
      self.assertEqual(0, self.evaluate(sgd.iterations))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testConfig(self):
    with testing_utils.use_gpu():
      opt = gradient_descent.SGD(learning_rate=1.0)
      config = opt.get_config()
      opt2 = gradient_descent.SGD.from_config(config)
      lr = opt._get_hyper('learning_rate')
      lr2 = opt2._get_hyper('learning_rate')
      self.evaluate(variables.global_variables_initializer())
      # assert both are equal float values.
      self.assertEqual(self.evaluate(lr), self.evaluate(lr2))
      var0 = variables.Variable([[1.0], [2.0]], dtype=dtypes.float32)
      loss = lambda: 3 * var0
      # learning rate variable is created when calling minimize.
      opt.minimize(loss, [var0])
      opt3 = gradient_descent.SGD.from_config(config)
      lr3 = opt3._get_hyper('learning_rate')
      self.evaluate(variables.global_variables_initializer())
      self.assertEqual(self.evaluate(lr), self.evaluate(lr3))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testConfigWithLearningRateDecay(self):
    with testing_utils.use_gpu():
      var0 = variables.Variable([[1.0], [2.0]], dtype=dtypes.float32)
      for decay_schedule in [
          learning_rate_schedule.InverseTimeDecay(
              0.5, decay_steps=1.0, decay_rate=0.1),
          learning_rate_schedule.PiecewiseConstantDecay(
              [5], [1., .5])
      ]:
        step = 10
        opt = gradient_descent.SGD(decay_schedule)
        config = opt.get_config()
        opt2 = gradient_descent.SGD.from_config(config)
        # assert both are equal float values.
        self.assertAllEqual(
            decay_schedule(step),
            opt._get_hyper('learning_rate')(step))
        self.assertAllEqual(
            decay_schedule(step),
            opt2._get_hyper('learning_rate')(step))
        loss = lambda: 3 * var0
        # learning rate variable is created when calling minimize.
        opt.minimize(loss, [var0])
        self.evaluate(variables.global_variables_initializer())
        config = opt.get_config()
        opt3 = gradient_descent.SGD.from_config(config)
        self.assertAllEqual(
            self.evaluate(opt._get_hyper('learning_rate')(step)),
            opt3._get_hyper('learning_rate')(step))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testGradClipValue(self):
    with testing_utils.use_gpu():
      var = variables.Variable([1.0, 2.0])
      loss = lambda: 3 * var
      opt = gradient_descent.SGD(learning_rate=1.0, clipvalue=1.0)
      opt_op = opt.minimize(loss, [var])
      self.evaluate(variables.global_variables_initializer())
      self.evaluate(opt_op)
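      # grad = [3., 3.] is clipped elementwise to [1., 1.], so
      # var = [1., 2.] - 1.0 * [1., 1.].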
      self.assertAllClose([0., 1.], self.evaluate(var))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testGradClipNorm(self):
    with testing_utils.use_gpu():
      var = variables.Variable([1.0])
      loss = lambda: 3 * var
      opt = gradient_descent.SGD(learning_rate=1.0, clipnorm=1.0)
      opt_op = opt.minimize(loss, [var])
      self.evaluate(variables.global_variables_initializer())
      self.evaluate(opt_op)
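      # grad = [3.] has L2 norm 3.0 and is rescaled to [1.], so
      # var = [1.] - 1.0 * [1.].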
      self.assertAllClose([0.], self.evaluate(var))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testGradGlobalClipNorm(self):
    with testing_utils.use_gpu():
      # The global L2 norm of the gradients ([3.], [4.]) is 5.0.
      var1 = variables.Variable([1.0])
      var2 = variables.Variable([2.0])
      loss = lambda: 3 * var1 + 4 * var2
      opt = gradient_descent.SGD(learning_rate=1.0, global_clipnorm=2.0)
      opt_op = opt.minimize(loss, [var1, var2])
      self.evaluate(variables.global_variables_initializer())
      self.evaluate(opt_op)
      # grad1 = 3.0 * 2.0 / 5.0 = 1.2
      self.assertAllClose([-.2], self.evaluate(var1))
      # grad2 = 4.0 * 2.0 / 5.0 = 1.6
      self.assertAllClose([.4], self.evaluate(var2))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testInvalidClipNorm(self):
    with self.assertRaisesRegex(ValueError, '>= 0'):
      gradient_descent.SGD(learning_rate=1.0, clipnorm=-1.0)

  @combinations.generate(
      combinations.combine(
          mode=['graph', 'eager'],
          clip_type=['clipnorm', 'global_clipnorm', 'clipvalue']))
  def testConfigWithClipping(self, clip_type):
    opt = gradient_descent.SGD(learning_rate=1.0, **{clip_type: 2.0})
    config = opt.get_config()
    opt = gradient_descent.SGD.from_config(config)
    self.assertEqual(getattr(opt, clip_type), 2.0)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testInvalidKwargs(self):
    with self.assertRaisesRegex(TypeError, 'Unexpected keyword argument'):
      gradient_descent.SGD(learning_rate=1.0, invalidkwargs=1.0)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testWeights(self):
    with testing_utils.use_gpu():
      opt1 = adam.Adam(learning_rate=1.0)
      var1 = variables.Variable([1.0, 2.0], dtype=dtypes.float32)
      loss1 = lambda: 3 * var1
      opt_op_1 = opt1.minimize(loss1, [var1])
      self.evaluate(variables.global_variables_initializer())
      config = opt1.get_config()
      opt2 = adam.Adam.from_config(config)
      var2 = variables.Variable([1.0, 2.0], dtype=dtypes.float32)
      loss2 = lambda: 3 * var2
      opt_op_2 = opt2.minimize(loss2, [var2])
      weights = opt1.get_weights()

      # Assert that set_weights updates both variables to the same value.
      self.evaluate(variables.global_variables_initializer())
      opt2.set_weights(weights)
      self.evaluate([opt_op_1, opt_op_2])
      self.assertAllClose(self.evaluate(var1), self.evaluate(var2))
      self.assertEqual(1, self.evaluate(opt1.iterations))
      self.assertEqual(1, self.evaluate(opt2.iterations))

      var3 = variables.Variable([1.0, 2.0, 3.0], dtype=dtypes.float32)
      var4 = variables.Variable([4.0, 5.0, 6.0], dtype=dtypes.float32)
      loss3 = lambda: 3 * var3 + 5 * var4
      opt_op_3 = opt1.minimize(loss3, [var3, var4])

      # Assert that set_weights raises a ValueError since the weight lists do not match.
      self.evaluate(variables.global_variables_initializer())
      weights = opt1.get_weights()
      with self.assertRaisesRegex(ValueError, 'but the optimizer was'):
        opt2.set_weights(weights)

      # Assert that set_weights updates the variables to the same value.
      var5 = variables.Variable([1.0, 2.0, 3.0], dtype=dtypes.float32)
      var6 = variables.Variable([4.0, 5.0, 6.0], dtype=dtypes.float32)
      loss4 = lambda: 3 * var5 + 5 * var6
      opt_op_4 = opt2.minimize(loss4, [var5, var6])
      self.evaluate(variables.global_variables_initializer())
      opt2.set_weights(weights)
      self.evaluate([opt_op_3, opt_op_4])
      self.assertAllClose(
          self.evaluate([var3, var4]), self.evaluate([var5, var6]))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testGettingHyperParameters(self):
    with self.test_session():
      opt = adam.Adam(learning_rate=1.0)
      var = variables.Variable([1.0, 2.0], dtype=dtypes.float32)
      loss = lambda: 3 * var
      opt_op = opt.minimize(loss, [var])
      self.evaluate(variables.global_variables_initializer())
      self.evaluate(opt_op)

      lr = self.evaluate(opt.lr)
      self.assertEqual(1.0, lr)

      opt.lr = 2.0
      lr = self.evaluate(opt.lr)
      self.assertEqual(2.0, lr)

      self.evaluate(opt.lr.assign(3.0))
      lr = self.evaluate(opt.lr)
      self.assertEqual(3.0, lr)

      with self.assertRaises(AttributeError):
        opt.not_an_attr += 3

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testGettingHyperParametersWithLrInConstructor(self):
    with self.test_session():
      opt = gradient_descent.SGD(lr=3.0)
      var = variables.Variable([1.0, 2.0], dtype=dtypes.float32)
      loss = lambda: 3 * var
      opt_op = opt.minimize(loss, [var])
      self.evaluate(variables.global_variables_initializer())
      self.evaluate(opt_op)

      self.assertIsInstance(opt.lr, variables.Variable)
      self.assertIsInstance(opt.learning_rate, variables.Variable)

      lr = self.evaluate(opt.lr)
      self.assertEqual(3.0, lr)

      opt.lr = 2.0
      lr = self.evaluate(opt.lr)
      self.assertEqual(2.0, lr)

      self.evaluate(opt.lr.assign(4.0))
      lr = self.evaluate(opt.lr)
      self.assertEqual(4.0, lr)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testDir(self):
    opt = gradient_descent.SGD(learning_rate=1.0, momentum=0.1)
    dir_result = set(dir(opt))
    self.assertIn('learning_rate', dir_result)  # Hyperparameter
    self.assertIn('lr', dir_result)  # Hyperparameter
    self.assertIn('momentum', dir_result)  # Hyperparameter
    self.assertIn('nesterov', dir_result)  # Attribute
    self.assertIn('minimize', dir_result)  # Attribute

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testOptimizerWithKerasModel(self):
    a = input_layer.Input(shape=(3,), name='input_a')
    b = input_layer.Input(shape=(3,), name='input_b')

    dense = core.Dense(4, name='dense')
    c = dense(a)
    d = dense(b)
    e = core.Dropout(0.5, name='dropout')(c)

    model = training.Model([a, b], [d, e])

    optimizer = gradient_descent.SGD(learning_rate=0.001)
    loss = 'mse'
    model.compile(optimizer, loss, metrics=['mae'])

    input_a_np = np.random.random((10, 3))
    input_b_np = np.random.random((10, 3))

    output_d_np = np.random.random((10, 4))
    output_e_np = np.random.random((10, 4))

    model.fit([input_a_np, input_b_np], [output_d_np, output_e_np],
              epochs=1,
              batch_size=5)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testOptimizerWithCallbacks(self):
    np.random.seed(1331)
    input_np = np.random.random((10, 3))
    output_np = np.random.random((10, 4))
    a = input_layer.Input(shape=(3,), name='input_a')
    model = sequential.Sequential()
    model.add(core.Dense(4, kernel_initializer='zeros', name='dense'))
    model.add(core.Dropout(0.5, name='dropout'))
    model(a)
    optimizer = gradient_descent.SGD(learning_rate=0.1)
    model.compile(optimizer, loss='mse', metrics=['mae'])
    # This does not reduce the LR after the first epoch (due to low delta).
    cbks = [
        callbacks.ReduceLROnPlateau(
            monitor='val_loss', factor=0.1, min_delta=0, patience=1, cooldown=5)
    ]
    model.fit(
        input_np,
        output_np,
        batch_size=10,
        validation_data=(input_np, output_np),
        callbacks=cbks,
        epochs=2,
        verbose=0)
    self.assertAllClose(
        float(backend.get_value(model.optimizer.lr)), 0.1, atol=1e-4)

    # This should reduce the LR after the first epoch (due to high delta).
    cbks = [
        callbacks.ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.1,
            min_delta=10,
            patience=1,
            cooldown=5)
    ]
    model.fit(
        input_np,
        output_np,
        batch_size=10,
        validation_data=(input_np, output_np),
        callbacks=cbks,
        epochs=2,
        verbose=2)
    self.assertAllClose(
        float(backend.get_value(model.optimizer.lr)), 0.01, atol=1e-4)

  def testOptimizerSetIterations(self):
    global_step = training_util.get_or_create_global_step()
    opt = adam.Adam(learning_rate=1.0)
    opt.iterations = global_step
    var = variables.Variable([1.0, 2.0], dtype=dtypes.float32)
    self.evaluate(variables.global_variables_initializer())
    init_step_value = self.evaluate(global_step)
    loss = lambda: 3 * var
    opt_op = opt.minimize(loss, [var])
    self.evaluate(variables.global_variables_initializer())
    self.evaluate(opt_op)
    new_step_value = self.evaluate(global_step)
    self.assertEqual(new_step_value, init_step_value + 1)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testOptimizerWithCallableVarList(self):
    train_samples = 20
    input_dim = 1
    num_classes = 2
    (x, y), _ = testing_utils.get_test_data(
        train_samples=train_samples,
        test_samples=10,
        input_shape=(input_dim,),
        num_classes=num_classes)
    y = np_utils.to_categorical(y)

    num_hidden = 1
    model = testing_utils.get_small_sequential_mlp(
        num_hidden=num_hidden, num_classes=num_classes)
    opt = adam.Adam()

    loss = lambda: losses.mean_squared_error(model(x), y)
    var_list = lambda: model.trainable_weights

    with self.assertRaisesRegex(
        ValueError, 'Weights for model .* have not yet been created'):
      var_list()
    train_op = opt.minimize(loss, var_list)
    if not context.executing_eagerly():
      self.evaluate(variables.global_variables_initializer())
      self.assertEqual(
          [[0.]], self.evaluate(opt.get_slot(var_list()[0], 'm')))
      self.evaluate(train_op)
    self.assertNotEqual(
        [[0.]], self.evaluate(opt.get_slot(var_list()[0], 'm')))
    self.assertLen(var_list(), 4)

  def testVarKey(self):
    with ops.get_default_graph().as_default():
      a = variables.Variable([1., 2.], name='var')
      b = variables.Variable([1.], name='var')
      self.assertTrue(a._in_graph_mode)
      self.assertTrue(b._in_graph_mode)
      var_key = optimizer_v2._var_key(a)
      self.assertEqual('var', var_key)
      var_key = optimizer_v2._var_key(b)
      self.assertEqual('var_1', var_key)

  def testVarName(self):
    with ops.get_default_graph().as_default():
      var = variables.Variable([1., 2.], name='var')
      loss = var + 1.
      opt = adam.Adam()
      opt.get_updates(loss, [var])
      opt_vars = opt.variables()
      self.assertLen(opt_vars, 3)
      self.assertEqual('Adam/iter:0', opt_vars[0].name)
      self.assertEqual('Adam/var/m:0', opt_vars[1].name)
      var_2 = variables.Variable([1., 2.], name='var_2')
      loss = var_2 + 1.
      with backend.name_scope('outer'):
        opt.get_updates(loss, [var_2])
      opt_vars = opt.variables()
      self.assertLen(opt_vars, 5)
      self.assertEqual('outer/Adam/var_2/m:0', opt_vars[3].name)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testEmptyVarList(self):
    opt = gradient_descent.SGD(1.)
    opt.minimize(lambda: constant_op.constant(1.), [])
    opt.apply_gradients([])

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testAggregationTrue(self):
    # Test that experimental_aggregate_gradients=True works without a
    # distribution strategy.
    var = variables.Variable([1., 2.])
    opt = gradient_descent.SGD(3.0)

    self.evaluate(variables.global_variables_initializer())
    self.assertAllClose([1., 2.], self.evaluate(var))
    opt_op = opt.apply_gradients([([0.1, 0.1], var)],
                                 experimental_aggregate_gradients=True)
    self.evaluate(variables.global_variables_initializer())
    self.evaluate(opt_op)
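    # var = [1., 2.] - 3.0 * [0.1, 0.1]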
    self.assertAllClose([0.7, 1.7], self.evaluate(var))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testAggregationFalse(self):
    # Test that experimental_aggregate_gradients=False works without a
    # distribution strategy.
    var = variables.Variable([1., 2.])
    opt = gradient_descent.SGD(3.0)

    self.evaluate(variables.global_variables_initializer())
    self.assertAllClose([1., 2.], self.evaluate(var))
    opt_op = opt.apply_gradients([([0.1, 0.1], var)],
                                 experimental_aggregate_gradients=False)
    self.evaluate(variables.global_variables_initializer())
    self.evaluate(opt_op)
    self.assertAllClose([0.7, 1.7], self.evaluate(var))

  @combinations.generate(combinations.combine(mode=['eager']))
  def testRestoringIterationsWithoutAnOptimizer(self):
    opt = gradient_descent.SGD(3.0)
    opt.iterations.assign(5)
    checkpoint = trackable_utils.Checkpoint(optimizer=opt)
    path = checkpoint.save(self.get_temp_dir())

    # The following verifies that `iterations` can be restored in the absence
    # of an `Optimizer` object (using a `Checkpoint` as a placeholder).
    iterations_var = variables.Variable(0, dtype=dtypes.int64)
    optimizer_checkpoint = trackable_utils.Checkpoint(iter=iterations_var)
    checkpoint_to_restore = trackable_utils.Checkpoint(
        optimizer=optimizer_checkpoint)
    checkpoint_to_restore.restore(path)

    self.assertEqual(5, self.evaluate(iterations_var))

  @combinations.generate(combinations.combine(mode=['eager']))
  def testSlotWithNonstandardShapeRestoresBasedOnCheckpoint(self):
    # First create an optimizer and a slot variable with a non-standard shape.
    x = variables.Variable([[1.0, 2.0], [3.0, 4.0]], dtype=dtypes.float32)
    slot_shape = [2, 1]
    optimizer_1 = optimizer_v2.OptimizerV2(name='test')
    optimizer_1.add_slot(x, 'test_slot', 'ones', shape=slot_shape)

    # Then save the variable and optimizer to a checkpoint.
    checkpoint_1 = trackable_utils.Checkpoint(var=x, optimizer=optimizer_1)
    checkpoint_path = checkpoint_1.save(self.get_temp_dir())

    # Create a new optimizer and call restore on it (and x)
    optimizer_2 = optimizer_v2.OptimizerV2(name='test')
    checkpoint_2 = trackable_utils.Checkpoint(var=x, optimizer=optimizer_2)
    checkpoint_2.restore(checkpoint_path)

    self.assertEqual(slot_shape,
                     optimizer_2.get_slot(x, 'test_slot').shape.as_list())

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_gradient_aggregator(self):
    def gradient_aggregator(grads_and_vars):
      # Simulate an all-reduce where the other replica has zeros for gradients,
      # by dividing each gradient by 2.
      grads = [g for g, _ in grads_and_vars]
      vars = [v for _, v in grads_and_vars]  # pylint: disable=redefined-builtin
      all_reduced_grads = [g / 2 for g in grads]
      return list(zip(all_reduced_grads, vars))

    var = variables.Variable(2.0)
    sgd = gradient_descent.SGD(1.0, gradient_aggregator=gradient_aggregator)
    loss = lambda: 2 * var
    opt_op = sgd.minimize(loss, var_list=[var])
    self.evaluate(variables.global_variables_initializer())
    self.evaluate(opt_op)
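    # The raw gradient 2.0 is halved to 1.0 by the aggregator:
    # var = 2.0 - 1.0 * 1.0.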
    self.assertEqual(self.evaluate(var), 1.0)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_override_aggregate_gradients(self):
    class MyOptimizer(gradient_descent.SGD):

      def _aggregate_gradients(self, grads_and_vars):
        # Simulate an all-reduce where the other replica has zeros for
        # gradients, by dividing each gradient by 2.
        grads = [g for g, _ in grads_and_vars]
        vars = [v for _, v in grads_and_vars]  # pylint: disable=redefined-builtin
        all_reduced_grads = [g / 2 for g in grads]
        return list(zip(all_reduced_grads, vars))

    var = variables.Variable(2.0)
    sgd = MyOptimizer(1.0)
    loss = lambda: 2 * var
    opt_op = sgd.minimize(loss, var_list=[var])
    self.evaluate(variables.global_variables_initializer())
    self.evaluate(opt_op)
    self.assertEqual(self.evaluate(var), 1.0)
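
The last two tests exercise the optimizer's gradient-aggregation hook. A
minimal standalone sketch, assuming TF >= 2.4 where tf.keras.optimizers.SGD
accepts a `gradient_aggregator` callable (the tests use the internal
`gradient_descent` module instead):

import tensorflow as tf


def halve_gradients(grads_and_vars):
    # Simulate an all-reduce with a zero-gradient peer by halving each grad.
    return [(g / 2, v) for g, v in grads_and_vars]


var = tf.Variable(2.0)
opt = tf.keras.optimizers.SGD(1.0, gradient_aggregator=halve_gradients)
opt.minimize(lambda: 2 * var, var_list=[var])
print(var.numpy())  # 1.0: raw grad 2.0 halved to 1.0, then 2.0 - 1.0 * 1.0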
Example #13
0
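CheckpointingTests references a `MyModel` helper defined earlier in the
original file and not shown in this excerpt. A sketch consistent with the
checkpoint names asserted in testNamingWithOptimizer below (a biased
`_named_dense`, a bias-free `_second`, and a non-Layer `_non_layer` owning
`a_variable`); `tracking` is assumed imported from
`tensorflow.python.training.tracking`:

class NonLayerTrackable(tracking.AutoTrackable):
    def __init__(self):
        super(NonLayerTrackable, self).__init__()
        self.a_variable = trackable_utils.add_variable(
            self, name="a_variable", shape=[])


class MyModel(training.Model):
    """A Model with both Layer and non-Layer dependencies."""
    def __init__(self):
        super(MyModel, self).__init__()
        self._named_dense = core.Dense(1, use_bias=True)
        self._second = core.Dense(1, use_bias=False)
        # Non-Layer Trackables are tracked as dependencies too.
        self._non_layer = NonLayerTrackable()

    def call(self, values):
        return self._second(self._named_dense(values))
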
class CheckpointingTests(keras_parameterized.TestCase):
    @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
    def testNamingWithOptimizer(self):
        input_value = constant_op.constant([[3.]])
        model = MyModel()
        # A nuisance Model using the same optimizer. Its slot variables should not
        # go in the checkpoint, since it is never depended on.
        other_model = MyModel()
        optimizer = adam.AdamOptimizer(0.001)
        optimizer_step = training_util.get_or_create_global_step()
        root_trackable = trackable_utils.Checkpoint(
            optimizer=optimizer, model=model, optimizer_step=optimizer_step)
        if context.executing_eagerly():
            optimizer.minimize(lambda: model(input_value),
                               global_step=optimizer_step)
            optimizer.minimize(lambda: other_model(input_value),
                               global_step=optimizer_step)
        else:
            train_op = optimizer.minimize(model(input_value),
                                          global_step=optimizer_step)
            optimizer.minimize(other_model(input_value),
                               global_step=optimizer_step)
            self.evaluate(trackable_utils.gather_initializers(root_trackable))
            self.evaluate(train_op)
        named_variables, serialized_graph, _ = graph_view.ObjectGraphView(
            root_trackable).serialize_object_graph()
        expected_checkpoint_names = (
            # Created in the root node, so no prefix.
            "optimizer_step",
            "model/_second/kernel",
            "model/_named_dense/kernel",
            "model/_named_dense/bias",
            # non-Layer dependency of the model
            "model/_non_layer/a_variable",
            # The optimizer creates two non-slot variables
            "optimizer/beta1_power",
            "optimizer/beta2_power",
            # Slot variables
            "model/_second/kernel/.OPTIMIZER_SLOT/optimizer/m",
            "model/_second/kernel/.OPTIMIZER_SLOT/optimizer/v",
            "model/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/m",
            "model/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/v",
            "model/_named_dense/bias/.OPTIMIZER_SLOT/optimizer/m",
            "model/_named_dense/bias/.OPTIMIZER_SLOT/optimizer/v",
        )
        suffix = "/.ATTRIBUTES/VARIABLE_VALUE"
        expected_checkpoint_names = [
            name + suffix for name in expected_checkpoint_names
        ]
        named_variables = {v.name: v for v in named_variables}
        six.assertCountEqual(self, expected_checkpoint_names,
                             named_variables.keys())
        # Check that we've mapped to the right variable objects (not exhaustive)
        self.assertEqual("global_step",
                         named_variables["optimizer_step" + suffix].full_name)
        self.assertEqual(
            "my_model/dense_1/kernel",
            named_variables["model/_second/kernel" + suffix].full_name)
        self.assertEqual(
            "my_model/dense/kernel",
            named_variables["model/_named_dense/kernel" + suffix].full_name)
        self.assertEqual(
            "beta1_power",
            named_variables["optimizer/beta1_power" + suffix].full_name)
        self.assertEqual(
            "beta2_power",
            named_variables["optimizer/beta2_power" + suffix].full_name)
        # Spot check the generated protocol buffers.
        self.assertEqual("optimizer",
                         serialized_graph.nodes[0].children[1].local_name)
        optimizer_node = serialized_graph.nodes[
            serialized_graph.nodes[0].children[1].node_id]
        self.assertEqual("beta1_power", optimizer_node.children[0].local_name)
        self.assertEqual(
            "beta1_power", serialized_graph.nodes[
                optimizer_node.children[0].node_id].attributes[0].full_name)
        self.assertEqual(
            "my_model/dense/kernel",
            serialized_graph.nodes[optimizer_node.slot_variables[
                0].original_variable_node_id].attributes[0].full_name)
        # We strip off the :0 suffix, as variable.name-based saving does.
        self.assertEqual(
            "my_model/dense/kernel/Adam",
            serialized_graph.nodes[optimizer_node.slot_variables[
                0].slot_variable_node_id].attributes[0].full_name)
        self.assertEqual(
            "my_model/dense/kernel/Adam:0",
            optimizer.get_slot(var=model._named_dense.kernel, name="m").name)
        self.assertEqual(
            "model/_named_dense/kernel" + suffix,
            serialized_graph.nodes[optimizer_node.slot_variables[
                0].original_variable_node_id].attributes[0].checkpoint_key)
        self.assertEqual("m", optimizer_node.slot_variables[0].slot_name)
        self.assertEqual(
            "model/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/m" + suffix,
            serialized_graph.nodes[optimizer_node.slot_variables[
                0].slot_variable_node_id].attributes[0].checkpoint_key)

    @combinations.generate(combinations.combine(mode=["graph", "eager"]))
    def testSaveRestore(self):
        with self.test_session():
            model = MyModel()
            optimizer = adam.AdamOptimizer(0.001)
            root_trackable = trackable_utils.Checkpoint(optimizer=optimizer,
                                                        model=model)
            input_value = constant_op.constant([[3.]])
            if context.executing_eagerly():
                optimizer.minimize(lambda: model(input_value))
            else:
                train_op = optimizer.minimize(model(input_value))
                # TODO(allenl): Make initialization more pleasant when graph building.
                root_trackable.save_counter  # pylint: disable=pointless-statement
                self.evaluate(
                    trackable_utils.gather_initializers(root_trackable))
                self.evaluate(train_op)
            prefix = os.path.join(self.get_temp_dir(), "ckpt")
            self.evaluate(
                state_ops.assign(model._named_dense.variables[1], [42.]))
            m_bias_slot = optimizer.get_slot(model._named_dense.variables[1],
                                             "m")
            self.evaluate(state_ops.assign(m_bias_slot, [1.5]))
            save_path = root_trackable.save(file_prefix=prefix)
            self.evaluate(
                state_ops.assign(model._named_dense.variables[1], [43.]))
            self.evaluate(state_ops.assign(root_trackable.save_counter, 3))
            optimizer_variables = self.evaluate(optimizer.variables())
            self.evaluate(state_ops.assign(m_bias_slot, [-2.]))
            # Immediate restoration
            status = root_trackable.restore(
                save_path=save_path).assert_consumed()
            status.run_restore_ops()
            self.assertAllEqual([42.],
                                self.evaluate(model._named_dense.variables[1]))
            self.assertAllEqual(1, self.evaluate(root_trackable.save_counter))
            self.assertAllEqual([1.5], self.evaluate(m_bias_slot))
            if not context.executing_eagerly():
                return  # Restore-on-create is only supported when executing eagerly
            on_create_model = MyModel()
            on_create_optimizer = adam.AdamOptimizer(
                0.001,
                # Preserve beta1_power and beta2_power when applying gradients
                # so we can test that they've been restored correctly.
                beta1=1.0,
                beta2=1.0)
            on_create_root = trackable_utils.Checkpoint(
                optimizer=on_create_optimizer, model=on_create_model)
            # Deferred restoration
            status = on_create_root.restore(save_path=save_path)
            status.assert_nontrivial_match()
            status.assert_existing_objects_matched()
            with self.assertRaises(AssertionError):
                status.assert_consumed()
            on_create_model(constant_op.constant([[3.]]))  # create variables
            self.assertAllEqual(1, self.evaluate(on_create_root.save_counter))
            self.assertAllEqual([42.],
                                self.evaluate(
                                    on_create_model._named_dense.variables[1]))
            on_create_m_bias_slot = on_create_optimizer.get_slot(
                on_create_model._named_dense.variables[1], "m")
            status.assert_existing_objects_matched()
            with self.assertRaises(AssertionError):
                status.assert_consumed()
            # Optimizer slot variables are created when the original variable is
            # restored.
            self.assertAllEqual([1.5], self.evaluate(on_create_m_bias_slot))
            self.assertAllEqual(optimizer_variables[2:],
                                self.evaluate(on_create_optimizer.variables()))
            dummy_var = variables.Variable([1.])
            on_create_optimizer.minimize(loss=dummy_var.read_value)
            status.assert_existing_objects_matched()
            status.assert_consumed()
            beta1_power, beta2_power = on_create_optimizer._get_beta_accumulators()
            self.assertAllEqual(optimizer_variables[0],
                                self.evaluate(beta1_power))
            self.assertAllEqual(optimizer_variables[1],
                                self.evaluate(beta2_power))

    # TODO(allenl): Debug garbage created by this test in python3.
    def testDeferredRestorationUsageEager(self):
        """An idiomatic eager execution example."""
        num_training_steps = 10
        checkpoint_directory = self.get_temp_dir()
        checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
        for training_continuation in range(3):
            model = MyModel()
            optimizer = adam.AdamOptimizer(0.001)
            root = trackable_utils.Checkpoint(
                optimizer=optimizer,
                model=model,
                optimizer_step=training_util.get_or_create_global_step())
            root.restore(
                checkpoint_management.latest_checkpoint(checkpoint_directory))
            for _ in range(num_training_steps):
                # TODO(allenl): Use a Dataset and serialize/checkpoint it.
                input_value = constant_op.constant([[3.]])
                optimizer.minimize(
                    lambda: model(input_value),  # pylint: disable=cell-var-from-loop
                    global_step=root.optimizer_step)
            root.save(file_prefix=checkpoint_prefix)
            self.assertEqual((training_continuation + 1) * num_training_steps,
                             root.optimizer_step.numpy())

    def testEagerDistributionStrategy(self):
        num_training_steps = 10
        checkpoint_directory = self.get_temp_dir()
        checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

        def _train_fn(optimizer, model):
            input_value = constant_op.constant([[3.]])
            optimizer.minimize(functools.partial(model, input_value),
                               global_step=root.optimizer_step)

        strategy = mirrored_strategy.MirroredStrategy()
        with strategy.scope():
            for training_continuation in range(3):
                model = MyModel()
                optimizer = adam.AdamOptimizer(0.001)
                root = trackable_utils.Checkpoint(
                    optimizer=optimizer,
                    model=model,
                    optimizer_step=training_util.get_or_create_global_step())
                root.restore(
                    checkpoint_management.latest_checkpoint(
                        checkpoint_directory))

                for _ in range(num_training_steps):
                    strategy.extended.call_for_each_replica(
                        functools.partial(_train_fn, optimizer, model))
                root.save(file_prefix=checkpoint_prefix)
                self.assertEqual(
                    (training_continuation + 1) * num_training_steps,
                    root.optimizer_step.numpy())

    def testGraphDistributionStrategy(self):
        self.skipTest("b/121381184")
        num_training_steps = 10
        checkpoint_directory = self.get_temp_dir()
        checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

        def _train_fn(optimizer, model):
            input_value = constant_op.constant([[3.]])
            return optimizer.minimize(functools.partial(model, input_value),
                                      global_step=root.optimizer_step)

        for training_continuation in range(3):
            with ops.Graph().as_default():
                strategy = mirrored_strategy.MirroredStrategy()
                with strategy.scope():
                    model = MyModel()
                    optimizer = adam.AdamOptimizer(0.001)
                    root = trackable_utils.Checkpoint(
                        optimizer=optimizer,
                        model=model,
                        optimizer_step=training_util.get_or_create_global_step())
                    status = root.restore(
                        checkpoint_management.latest_checkpoint(
                            checkpoint_directory))
                    train_op = strategy.extended.call_for_each_replica(
                        functools.partial(_train_fn, optimizer, model))
                    with self.session() as session:
                        if training_continuation > 0:
                            status.assert_consumed()
                        status.initialize_or_restore()
                        for _ in range(num_training_steps):
                            session.run(train_op)
                        root.save(file_prefix=checkpoint_prefix)
                self.assertEqual(
                    (training_continuation + 1) * num_training_steps,
                    root.optimizer_step.numpy())

    def testUsageGraph(self):
        """Expected usage when graph building."""
        with context.graph_mode():
            num_training_steps = 10
            checkpoint_directory = self.get_temp_dir()
            checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
            for training_continuation in range(3):
                with ops.Graph().as_default():
                    model = MyModel()
                    optimizer = adam.AdamOptimizer(0.001)
                    root = trackable_utils.CheckpointV1(
                        optimizer=optimizer,
                        model=model,
                        global_step=training_util.get_or_create_global_step())
                    input_value = constant_op.constant([[3.]])
                    train_op = optimizer.minimize(model(input_value),
                                                  global_step=root.global_step)
                    checkpoint_path = checkpoint_management.latest_checkpoint(
                        checkpoint_directory)
                    with self.session(
                            graph=ops.get_default_graph()) as session:
                        status = root.restore(save_path=checkpoint_path)
                        status.initialize_or_restore(session=session)
                        if checkpoint_path is None:
                            self.assertEqual(0, training_continuation)
                            with self.assertRaises(AssertionError):
                                status.assert_consumed()
                            with self.assertRaises(AssertionError):
                                status.assert_existing_objects_matched()
                        else:
                            status.assert_consumed()
                            status.assert_existing_objects_matched()
                        for _ in range(num_training_steps):
                            session.run(train_op)
                        root.save(file_prefix=checkpoint_prefix,
                                  session=session)
                        self.assertEqual(
                            (training_continuation + 1) * num_training_steps,
                            session.run(root.global_step))
                        self.assertEqual(training_continuation + 1,
                                         session.run(root.save_counter))

    @combinations.generate(combinations.combine(mode=["graph", "eager"]))
    def testAgnosticUsage(self):
        """Graph/eager agnostic usage."""
        # Does create garbage when executing eagerly due to ops.Graph() creation.
        with self.test_session():
            num_training_steps = 10
            checkpoint_directory = self.get_temp_dir()
            for training_continuation in range(3):
                with testing_utils.device(should_use_gpu=True):
                    model = MyModel()
                    optimizer = adam.AdamOptimizer(0.001)
                    root = trackable_utils.Checkpoint(
                        optimizer=optimizer,
                        model=model,
                        global_step=training_util.get_or_create_global_step())
                    manager = checkpoint_management.CheckpointManager(
                        root, checkpoint_directory, max_to_keep=1)
                    status = root.restore(save_path=manager.latest_checkpoint)
                    input_value = constant_op.constant([[3.]])
                    train_fn = functools.partial(optimizer.minimize,
                                                 functools.partial(
                                                     model, input_value),
                                                 global_step=root.global_step)
                    if not context.executing_eagerly():
                        train_fn = functools.partial(self.evaluate, train_fn())
                    status.initialize_or_restore()
                    for _ in range(num_training_steps):
                        train_fn()
                    manager.save()
                    self.assertEqual(
                        (training_continuation + 1) * num_training_steps,
                        self.evaluate(root.global_step))
                    self.assertEqual(training_continuation + 1,
                                     self.evaluate(root.save_counter))

    # pylint: disable=cell-var-from-loop
    @combinations.generate(combinations.combine(mode=["graph", "eager"]))
    def testWithDefun(self):
        with self.test_session():
            num_training_steps = 2
            checkpoint_directory = self.get_temp_dir()
            checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
            for training_continuation in range(3):
                with testing_utils.device(should_use_gpu=True):
                    model = MyModel()
                    # Don't actually train so we can test variable values
                    optimizer = adam.AdamOptimizer(0.)
                    root = trackable_utils.Checkpoint(
                        optimizer=optimizer,
                        model=model,
                        global_step=training_util.get_or_create_global_step())
                    checkpoint_path = checkpoint_management.latest_checkpoint(
                        checkpoint_directory)
                    status = root.restore(save_path=checkpoint_path)

                    def train_fn():
                        @def_function.function
                        def _call_model(x):
                            return model(x)

                        with backprop.GradientTape() as tape:
                            loss = _call_model(constant_op.constant([[3.]]))
                        gradients = tape.gradient(loss, model.variables)
                        return optimizer.apply_gradients(
                            zip(gradients, model.variables),
                            global_step=root.global_step)

                    if not context.executing_eagerly():
                        train_fn = functools.partial(self.evaluate, train_fn())
                    status.initialize_or_restore()
                    for _ in range(num_training_steps):
                        train_fn()
                    if training_continuation > 0:
                        status.assert_consumed()
                        self.assertAllClose([[42.]],
                                            self.evaluate(model.variables[0]))
                    else:
                        self.evaluate(model.variables[0].assign([[42.]]))
                    root.save(file_prefix=checkpoint_prefix)
                    self.assertEqual(
                        (training_continuation + 1) * num_training_steps,
                        self.evaluate(root.global_step))
                    self.assertEqual(training_continuation + 1,
                                     self.evaluate(root.save_counter))

    # pylint: enable=cell-var-from-loop

    def _get_checkpoint_name(self, name):
        root = module.Module()
        trackable_utils.add_variable(root,
                                     name=name,
                                     shape=[1, 2],
                                     dtype=dtypes.float64)
        (named_variable, ), _, _ = trackable_utils._serialize_object_graph(
            root, saveables_cache=None)
        with ops.name_scope_v2("root/" + named_variable.name):
            pass  # Make sure we can use this as an op name if we prefix it.
        return named_variable.name

    def testAnonymousVarsInInit(self):
        class Model(training.Model):
            def __init__(self):
                super(Model, self).__init__()
                self.w = variables.Variable(0.0)
                self.b = variables.Variable(0.0)
                self.vars = [self.w, self.b]

            def call(self, x):
                return x * self.w + self.b

        with context.eager_mode():
            model = Model()
            optimizer = adam.AdamOptimizer(learning_rate=0.05)
            checkpoint_directory = self.get_temp_dir()
            checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
            checkpoint = trackable_utils.Checkpoint(model=model,
                                                    optimizer=optimizer)
            for _ in range(2):
                checkpoint.save(checkpoint_prefix)
                with backprop.GradientTape() as tape:
                    loss = (constant_op.constant(1.) -
                            model(constant_op.constant(1.)))**2
                grad = tape.gradient(loss, model.vars)
                optimizer.apply_gradients([(g, v)
                                           for g, v in zip(grad, model.vars)])

    @combinations.generate(combinations.combine(mode=["graph", "eager"]))
    def test_initialize_if_not_restoring(self):
        with self.test_session():
            checkpoint_directory = self.get_temp_dir()
            checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
            optimizer_only_prefix = os.path.join(checkpoint_directory, "opt")
            with testing_utils.device(should_use_gpu=True):
                model = MyModel()
                optimizer = adam.AdamOptimizer(0.001)
                root = trackable_utils.Checkpoint(
                    model=model,  # Do not save the optimizer with the checkpoint.
                    global_step=training_util.get_or_create_global_step())
                optimizer_checkpoint = trackable_utils.Checkpoint(
                    optimizer=optimizer)

                checkpoint_path = checkpoint_management.latest_checkpoint(
                    checkpoint_directory)
                status = root.restore(save_path=checkpoint_path)
                input_value = constant_op.constant([[3.]])
                train_fn = functools.partial(optimizer.minimize,
                                             functools.partial(
                                                 model, input_value),
                                             global_step=root.global_step)
                if not context.executing_eagerly():
                    train_fn = functools.partial(self.evaluate, train_fn())
                status.initialize_or_restore()
                self.evaluate([v.initializer for v in optimizer.variables()])
                train_fn()
                model_save_path = root.save(file_prefix=checkpoint_prefix)
                self.evaluate(optimizer.variables()[0].assign(42.))
                optimizer_save_path = optimizer_checkpoint.save(
                    optimizer_only_prefix)

            # Restore into a graph with the optimizer
            with testing_utils.device(should_use_gpu=True):
                model = MyModel()
                optimizer = adam.AdamOptimizer(0.001)
                root = trackable_utils.Checkpoint(
                    optimizer=optimizer,
                    model=model,
                    global_step=training_util.get_or_create_global_step())
                status = root.restore(save_path=model_save_path)
                input_value = constant_op.constant([[3.]])
                train_fn = functools.partial(optimizer.minimize,
                                             functools.partial(
                                                 model, input_value),
                                             global_step=root.global_step)
                if not context.executing_eagerly():
                    train_fn = functools.partial(self.evaluate, train_fn())
                status.initialize_or_restore()
                train_fn()
                with self.assertRaises(AssertionError):
                    status.assert_existing_objects_matched()
                with self.assertRaises(AssertionError):
                    status.assert_consumed()

            # Make sure initialization doesn't clobber later restores
            with testing_utils.device(should_use_gpu=True):
                model = MyModel()
                optimizer = adam.AdamOptimizer(0.001, beta1=1.0)
                root = trackable_utils.Checkpoint(
                    optimizer=optimizer,
                    model=model,
                    global_step=training_util.get_or_create_global_step())
                opt_root = trackable_utils.Checkpoint(optimizer=optimizer)
                status = root.restore(save_path=model_save_path)
                init_only_optimizer_status = opt_root.restore(save_path=None)
                optimizer_status = opt_root.restore(
                    save_path=optimizer_save_path)
                input_value = constant_op.constant([[3.]])
                train_fn = functools.partial(optimizer.minimize,
                                             functools.partial(
                                                 model, input_value),
                                             global_step=root.global_step)
                if not context.executing_eagerly():
                    train_fn = functools.partial(self.evaluate, train_fn())
                optimizer_status.run_restore_ops()
                status.initialize_or_restore()
                init_only_optimizer_status.initialize_or_restore()
                train_fn()
                self.assertEqual(42., self.evaluate(optimizer.variables()[0]))
Example #14
class CheckpointCompatibilityTests(keras_parameterized.TestCase):
    def _initialized_model(self):
        input_value = constant_op.constant([[3.]])
        model = MyModel()
        optimizer = adam.AdamOptimizer(0.001)
        optimizer_step = training_util.get_or_create_global_step()
        root_trackable = trackable_utils.Checkpoint(
            optimizer=optimizer, model=model, optimizer_step=optimizer_step)
        train_op = optimizer.minimize(functools.partial(model, input_value),
                                      global_step=optimizer_step)
        self.evaluate(trackable_utils.gather_initializers(root_trackable))
        self.evaluate(train_op)
        # Set a regular variable, a slot variable, and a non-slot Optimizer
        # variable to known values to check when loading.
        self.evaluate(model._named_dense.bias.assign([1.]))
        self.evaluate(
            optimizer.get_slot(var=model._named_dense.bias,
                               name="m").assign([2.]))
        beta1_power, _ = optimizer._get_beta_accumulators()
        self.evaluate(beta1_power.assign(3.))
        return root_trackable

    def _set_sentinels(self, root_trackable):
        self.evaluate(root_trackable.model._named_dense.bias.assign([101.]))
        self.evaluate(
            root_trackable.optimizer.get_slot(
                var=root_trackable.model._named_dense.bias,
                name="m").assign([102.]))
        beta1_power, _ = root_trackable.optimizer._get_beta_accumulators()
        self.evaluate(beta1_power.assign(103.))

    def _check_sentinels(self, root_trackable):
        self.assertAllEqual([1.],
                            self.evaluate(
                                root_trackable.model._named_dense.bias))
        self.assertAllEqual([2.],
                            self.evaluate(
                                root_trackable.optimizer.get_slot(
                                    var=root_trackable.model._named_dense.bias,
                                    name="m")))
        beta1_power, _ = root_trackable.optimizer._get_beta_accumulators()
        self.assertAllEqual(3., self.evaluate(beta1_power))

    def _write_name_based_checkpoint(self):
        checkpoint_directory = self.get_temp_dir()
        checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
        with context.graph_mode():
            save_graph = ops.Graph()
            with save_graph.as_default(), self.session(
                    graph=save_graph) as session:
                root = self._initialized_model()
                name_saver = saver_lib.Saver()
                return name_saver.save(sess=session,
                                       save_path=checkpoint_prefix,
                                       global_step=root.optimizer_step)

    @combinations.generate(combinations.combine(mode=["graph", "eager"]))
    def testLoadFromNameBasedSaver(self):
        """Save a name-based checkpoint, load it using the object-based API."""
        with testing_utils.device(should_use_gpu=True):
            with self.test_session():
                save_path = self._write_name_based_checkpoint()
                root = self._initialized_model()
                self._set_sentinels(root)
                with self.assertRaises(AssertionError):
                    self._check_sentinels(root)
                object_saver = trackable_utils.TrackableSaver(
                    graph_view.ObjectGraphView(root))
                self._set_sentinels(root)
                status = object_saver.restore(save_path)
                if context.executing_eagerly():
                    self._check_sentinels(root)
                    status.assert_consumed()
                    status.assert_existing_objects_matched()
                    status.assert_nontrivial_match()
                else:
                    # When graph building, we haven't read any keys, so we don't know
                    # whether the restore will be complete.
                    with self.assertRaisesRegex(AssertionError,
                                                "not restored"):
                        status.assert_consumed()
                    with self.assertRaisesRegex(AssertionError,
                                                "not restored"):
                        status.assert_existing_objects_matched()
                    with self.assertRaisesRegex(AssertionError,
                                                "not restored"):
                        status.assert_nontrivial_match()
                status.run_restore_ops()
                self._check_sentinels(root)
                self._set_sentinels(root)
                status = object_saver.restore(save_path)
                status.initialize_or_restore()
                self._check_sentinels(root)
                # Check that there is no error when keys are missing from the name-based
                # checkpoint.
                root.not_in_name_checkpoint = variables.Variable([1.])
                status = object_saver.restore(save_path)
                with self.assertRaises(AssertionError):
                    status.assert_existing_objects_matched()

    def testSaveGraphLoadEager(self):
        checkpoint_directory = self.get_temp_dir()
        checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
        with context.graph_mode():
            save_graph = ops.Graph()
            with save_graph.as_default(), self.session(graph=save_graph):
                root = self._initialized_model()
                save_path = root.save(file_prefix=checkpoint_prefix)
        with context.eager_mode():
            root = self._initialized_model()
            self._set_sentinels(root)
            root.restore(save_path).assert_consumed()
            self._check_sentinels(root)

    def testSaveEagerLoadGraph(self):
        checkpoint_directory = self.get_temp_dir()
        checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
        with context.eager_mode():
            root = self._initialized_model()
            save_path = root.save(file_prefix=checkpoint_prefix)
        with context.graph_mode():
            save_graph = ops.Graph()
            with save_graph.as_default(), self.session(graph=save_graph):
                root = self._initialized_model()
                self._set_sentinels(root)
                root.restore(save_path).assert_consumed().run_restore_ops()
                self._check_sentinels(root)
Example #15
import tensorflow as tf

from absl.testing import parameterized
from tensorflow.python.keras import combinations

from tensorflow_manopt.manifolds.test_invariants import (
    TestInvariants,
    random_constant,
)
from tensorflow_manopt.manifolds.poincare import Poincare


@combinations.generate(
    combinations.combine(
        mode=["graph", "eager"],
        manifold=[Poincare(), Poincare(k=5.0)],
        shape=[(5, ), (2, 5)],
        dtype=[tf.float32, tf.float64],
    ))
class PoincareTest(tf.test.TestCase, parameterized.TestCase):
    test_random = TestInvariants.check_random

    test_dist = TestInvariants.check_dist

    test_inner = TestInvariants.check_inner

    test_proj = TestInvariants.check_proj

    test_exp_log_inverse = TestInvariants.check_exp_log_inverse

    test_transp_retr = TestInvariants.check_transp_retr
Example #16
class KerasModelTest(keras_parameterized.TestCase):
    """Test mixed precision with Keras models."""
    def _skip_if_strategy_unsupported(self, strategy_fn):
        if (strategy_fn != default_strategy_fn
                and testing_utils.get_model_type() == 'subclass'):
            self.skipTest(
                'Non-default strategies are unsupported with subclassed '
                'models')

    def _skip_if_save_format_unsupported(self, save_format):
        model_type = testing_utils.get_model_type()
        if save_format == 'h5' and model_type == 'subclass':
            self.skipTest('Saving subclassed models with the HDF5 format is '
                          'unsupported')
        if (save_format == 'tf' and model_type == 'subclass'
                and not context.executing_eagerly()):
            self.skipTest(
                'b/148820505: This combination of features is currently '
                'broken.')

    @keras_parameterized.run_with_all_model_types
    @keras_parameterized.run_all_keras_modes
    @parameterized.named_parameters(
        {
            'testcase_name': 'base',
            'strategy_fn': default_strategy_fn
        }, {
            'testcase_name': 'distribute',
            'strategy_fn': create_mirrored_strategy,
        }, {
            'testcase_name': 'operator',
            'strategy_fn': create_mirrored_strategy,
            'use_operator': True
        }, {
            'testcase_name': 'regularizer',
            'strategy_fn': create_mirrored_strategy,
            'use_regularizer': True
        }, {
            'testcase_name': 'get_config',
            'strategy_fn': create_mirrored_strategy,
            'get_config': True,
            'use_regularizer': True,
        }, {
            'testcase_name': 'saved_model',
            'strategy_fn': default_strategy_fn,
            'save_format': 'tf',
            'use_regularizer': True,
        }, {
            'testcase_name': 'saved_model_input_spec',
            'strategy_fn': default_strategy_fn,
            'save_format': 'tf',
            'use_regularizer': True,
            'use_input_spec': True,
        }, {
            'testcase_name': 'h5',
            'strategy_fn': default_strategy_fn,
            'save_format': 'h5',
            'use_regularizer': True,
        }, {
            'testcase_name': 'saved_model_distribute',
            'strategy_fn': create_mirrored_strategy,
            'save_format': 'tf',
            'use_regularizer': True,
        }, {
            'testcase_name': 'saved_model_input_spec_distribute',
            'strategy_fn': create_mirrored_strategy,
            'save_format': 'tf',
            'use_regularizer': True,
            'use_input_spec': True,
        }, {
            'testcase_name': 'h5_distribute',
            'strategy_fn': create_mirrored_strategy,
            'save_format': 'h5',
            'use_regularizer': True,
        }, {
            'testcase_name': 'saved_model_v1_policy',
            'strategy_fn': create_mirrored_strategy,
            'use_v1_policy': True,
            'save_format': 'tf',
        })
    def test_model(self,
                   strategy_fn,
                   use_operator=False,
                   use_regularizer=False,
                   policy_name='mixed_float16',
                   get_config=False,
                   save_format=None,
                   use_input_spec=False,
                   use_v1_policy=False):
        self._skip_if_strategy_unsupported(strategy_fn)
        self._skip_if_save_format_unsupported(save_format)
        if use_regularizer:
            weight_regularizer = mp_test_util.IdentityRegularizer()
            activity_regularizer = mp_test_util.ReduceSumRegularizer()
        else:
            weight_regularizer = activity_regularizer = None
        with strategy_fn().scope():
            cls = policy.PolicyV1 if use_v1_policy else policy.Policy
            with policy.policy_scope(cls(policy_name)):
                layer = mp_test_util.MultiplyLayer(
                    assert_type=dtypes.float16,
                    use_operator=use_operator,
                    regularizer=weight_regularizer,
                    activity_regularizer=activity_regularizer,
                    input_shape=(1, ))
                if use_input_spec:
                    layer.input_spec = input_spec.InputSpec(shape=(None, 1))
                model = testing_utils.get_model_from_layers(
                    [layer], input_shape=(1, ), input_dtype=dtypes.float16)
                if get_config:
                    config = model.get_config()
                    model = model.__class__.from_config(
                        config,
                        custom_objects={
                            'MultiplyLayer': mp_test_util.MultiplyLayer
                        })
                    (layer, ) = (
                        layer for layer in model.layers
                        if isinstance(layer, mp_test_util.MultiplyLayer))

                def loss_fn(y_true, y_pred):
                    del y_true
                    return math_ops.reduce_mean(y_pred)

                # The learning rate is small enough that, if it were applied to
                # a float16 variable, the variable would not change. So this
                # tests that the learning rate is applied to the float32
                # variable, not to a float16 copy of it.
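                # (In float16, the spacing of representable values just below
                # 1.0 is 2 ** -11, so 1.0 - 2 ** -14 would round back to 1.0;
                # in float32 the subtraction is exact.)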
                opt = gradient_descent.SGD(2**-14)
                # Use a fixed loss scale, as this test will fail if gradients are
                # skipped for a step due to dynamic loss scaling.
                opt = loss_scale_optimizer.LossScaleOptimizer(opt,
                                                              dynamic=False,
                                                              initial_scale=8)
                model.compile(opt,
                              loss=loss_fn,
                              run_eagerly=testing_utils.should_run_eagerly())

        x = np.ones((2, 1))
        y = np.ones((2, 1))
        dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).batch(2)
        model.fit(dataset)
        # The variable starts at 1, and one training step subtracts
        # lr * gradient = 2 ** -14 from it.
        expected = 1 - 2**-14
        if use_regularizer:
            # Weight and activity regularizer each add another 2 ** -14 to the
            # gradient.
            expected -= 2 * 2**-14
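            # i.e. expected = 1 - 3 * 2 ** -14 when the regularizers are used.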
        self.assertEqual(backend.eval(layer.v), expected)

        if save_format:
            with generic_utils.CustomObjectScope({
                    'MultiplyLayer': mp_test_util.MultiplyLayer,
                    'loss_fn': loss_fn,
            }):
                self._test_saving(model, dataset, save_format, use_regularizer)

    def _test_saving(self, model, dataset, save_format, use_regularizer):
        # Save and load model, asserting variable does not change
        save_path = os.path.join(self.get_temp_dir(), 'model')
        model.save(save_path, save_format=save_format)
        model = save.load_model(save_path)
        (layer, ) = (layer for layer in model.layers
                     if 'MultiplyLayer' in layer.__class__.__name__)
        expected = 1 - 2**-14
        if use_regularizer:
            expected -= 2 * 2**-14
        self.assertEqual(backend.eval(layer.v), expected)

        # Continue training, and assert variable is correct value
        model.fit(dataset)
        new_expected = expected - 2**-14
        if use_regularizer:
            new_expected -= 2 * 2**-14
        self.assertEqual(backend.eval(layer.v), new_expected)

        # Load saved model again, and assert variable is previous value
        model = save.load_model(save_path)
        (layer, ) = (layer for layer in model.layers
                     if 'MultiplyLayer' in layer.__class__.__name__)
        self.assertEqual(backend.eval(layer.v), expected)

        # Ensure various dtype-related aspects of the layer are correct
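        # (Under mixed_float16, variables are stored in float32 while
        # computations run in float16.)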
        self.assertEqual(layer.dtype, 'float32')
        self.assertEqual(
            get_layer_policy.get_layer_policy(layer).name, 'mixed_float16')
        self.assertEqual(layer.v.dtype, 'float32')
        self.assertEqual(layer(np.ones((2, 1))).dtype, 'float16')

        # Loading a model always loads with a v2 Policy, even if saved with a
        # PolicyV1.
        self.assertEqual(type(model.dtype_policy), policy.Policy)
        self.assertEqual(layer.get_config()['dtype'], {
            'class_name': 'Policy',
            'config': {
                'name': 'mixed_float16'
            }
        })

    @keras_parameterized.run_all_keras_modes
    @parameterized.named_parameters(
        {
            'testcase_name': 'base',
            'strategy_fn': default_strategy_fn
        }, {
            'testcase_name': 'distribute',
            'strategy_fn': create_mirrored_strategy,
        })
    def test_fixed_loss_scaling(self, strategy_fn):
        # Note: We do not test mixed precision in this method, only loss scaling.
        loss_scale = 8.
        batch_size = 4
        with strategy_fn().scope():
            x = layers.Input(shape=(1, ), batch_size=batch_size)
            layer = mp_test_util.MultiplyLayer()
            y = layer(x)

            # The gradient of 'y' at this point is 1. With loss scaling, the gradient
            # is 'loss_scale'. We divide by the batch size since the loss is averaged
            # across batch elements.
            expected_gradient = loss_scale / batch_size
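            # Here that is 8. / 4 = 2.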
            identity_with_grad_check_fn = (
                mp_test_util.create_identity_with_grad_check_fn(
                    [expected_gradient]))
            y = core.Lambda(identity_with_grad_check_fn)(y)
            model = models.Model(inputs=x, outputs=y)

            def loss_fn(y_true, y_pred):
                del y_true
                return math_ops.reduce_mean(y_pred)

            opt = gradient_descent.SGD(1.)
            opt = loss_scale_optimizer.LossScaleOptimizer(
                opt, dynamic=False, initial_scale=loss_scale)
            model.compile(opt,
                          loss=loss_fn,
                          run_eagerly=testing_utils.should_run_eagerly())

        self.assertEqual(backend.eval(layer.v), 1)
        x = np.ones((batch_size, 1))
        y = np.ones((batch_size, 1))
        dataset = dataset_ops.Dataset.from_tensor_slices(
            (x, y)).batch(batch_size)
        model.fit(dataset)
        # The variable starts at 1 and has lr * gradient = 1 subtracted from it.
        expected = 0
        self.assertEqual(backend.eval(layer.v), expected)

    @keras_parameterized.run_all_keras_modes
    @parameterized.named_parameters(
        {
            'testcase_name': 'base',
            'strategy_fn': default_strategy_fn
        }, {
            'testcase_name': 'distribute',
            'strategy_fn': create_mirrored_strategy,
        }, {
            'testcase_name': 'loss_scaling',
            'strategy_fn': create_mirrored_strategy,
            'use_loss_scaling': True
        })
    def test_advanced_model(self, strategy_fn, use_loss_scaling=False):
        # The advanced model tests mixed-precision-related features that would occur
        # in a resnet50 model. It tests a model that has:
        #  * Multiple layers, some which use auto-cast variables and some which do
        #    not
        #  * Regularization on some variables and not others.
        #  * A fixed loss scale (if use_loss_scaling is True)

        strategy = strategy_fn()
        if use_loss_scaling:
            loss_scale = 8.
        learning_rate = 2**-14

        with strategy.scope():
            with policy.policy_scope(policy.Policy('mixed_float16')):
                x = layers.Input(shape=(1, ), batch_size=2)
                layer1 = mp_test_util.MultiplyLayer(
                    assert_type=dtypes.float16,
                    regularizer=mp_test_util.IdentityRegularizer(),
                    use_operator=True)
                layer2 = mp_test_util.MultiplyLayerWithoutAutoCast(
                    assert_type=dtypes.float16, use_operator=True)
                layer3 = mp_test_util.MultiplyLayer(assert_type=dtypes.float16,
                                                    use_operator=False)
                layer4 = mp_test_util.MultiplyLayerWithoutAutoCast(
                    assert_type=dtypes.float16,
                    regularizer=mp_test_util.IdentityRegularizer(),
                    use_operator=False)
                y = layer1(x)
                y = layer2(y)
                y = layer3(y)
                y = layer4(y)
                if use_loss_scaling:
                    # The gradient of 'y' at this point is 1. With loss scaling, the
                    # gradient is 'loss_scale'. We divide by the batch size of 2 since the
                    # loss is averaged across batch elements.
                    expected_gradient = loss_scale / 2
                    identity_with_grad_check_fn = (
                        mp_test_util.create_identity_with_grad_check_fn(
                            expected_dtype=dtypes.float16,
                            expected_gradient=[expected_gradient]))
                    y = core.Lambda(identity_with_grad_check_fn)(y)
                model = models.Model(inputs=x, outputs=y)

                def loss_fn(y_true, y_pred):
                    del y_true
                    return math_ops.reduce_mean(y_pred)

                opt = gradient_descent.SGD(learning_rate)
                if use_loss_scaling:
                    opt = loss_scale_optimizer.LossScaleOptimizer(
                        opt, dynamic=False, initial_scale=loss_scale)
                model.compile(opt,
                              loss=loss_fn,
                              run_eagerly=testing_utils.should_run_eagerly())

        x = np.ones((2, 1))
        y = np.ones((2, 1))
        dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).batch(2)
        model.fit(dataset)
        for layer in (layer1, layer2, layer3, layer4):
            if layer.losses:
                # Layer has weight regularizer
                self.assertEqual(backend.eval(layer.v), 1 - 2 * learning_rate)
            else:
                # Layer does not have weight regularizer
                self.assertEqual(backend.eval(layer.v), 1 - learning_rate)

    @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
    @parameterized.named_parameters(
        {
            'testcase_name': 'base',
            'strategy_fn': default_strategy_fn
        }, {
            'testcase_name': 'distribute',
            'strategy_fn': create_mirrored_strategy,
        }, {
            'testcase_name': 'pass_loss_scale_to_policy',
            'strategy_fn': create_mirrored_strategy,
            'pass_loss_scale_to_policy': True,
        }, {
            'testcase_name': 'get_config',
            'strategy_fn': create_mirrored_strategy,
            'get_config': True,
        }, {
            'testcase_name': 'get_config_v1_lso',
            'strategy_fn': create_mirrored_strategy,
            'get_config': True,
            'use_v1_loss_scale_optimizer': True,
        }, {
            'testcase_name': 'get_config_and_pass_loss_scale_to_policy',
            'strategy_fn': create_mirrored_strategy,
            'get_config': True,
            'pass_loss_scale_to_policy': True,
        })
    def test_dynamic_loss_scaling(self,
                                  strategy_fn,
                                  pass_loss_scale_to_policy=False,
                                  get_config=False,
                                  use_v1_loss_scale_optimizer=False):
        strategy = strategy_fn()
        initial_loss_scale = 2.
        batch_size = 4
        expected_gradient = backend.variable([initial_loss_scale / batch_size],
                                             dtype=dtypes.float16)
        # If this variable is set to True, the model below will have NaN gradients
        have_nan_gradients = backend.variable(False, dtype=dtypes.bool)
        with strategy.scope():
            opt = gradient_descent.SGD(1.)
            if pass_loss_scale_to_policy:
                loss_scale = loss_scale_module.DynamicLossScale(
                    initial_loss_scale=initial_loss_scale, increment_period=2)
                p = policy.PolicyV1('mixed_float16', loss_scale=loss_scale)
            elif use_v1_loss_scale_optimizer:
                loss_scale = loss_scale_module.DynamicLossScale(
                    initial_loss_scale=initial_loss_scale, increment_period=2)
                p = policy.Policy('mixed_float16')
                opt = loss_scale_optimizer.LossScaleOptimizerV1(
                    opt, loss_scale)
            else:
                p = policy.Policy('mixed_float16')
                opt = loss_scale_optimizer.LossScaleOptimizer(
                    opt,
                    initial_scale=initial_loss_scale,
                    dynamic_growth_steps=2)
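            # Dynamic loss scaling doubles the scale after 2 consecutive steps
            # (dynamic_growth_steps / increment_period) with finite gradients,
            # and halves it, skipping the variable update, on any step with
            # non-finite gradients.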
            with policy.policy_scope(p):
                x = layers.Input(shape=(1, ),
                                 batch_size=batch_size,
                                 dtype=dtypes.float16)
                layer = mp_test_util.MultiplyLayer(assert_type=dtypes.float16)
                y = layer(x)
                identity_with_nan_grads = (
                    mp_test_util.create_identity_with_nan_gradients_fn(
                        have_nan_gradients))
                y = core.Lambda(identity_with_nan_grads)(y)
                identity_with_grad_check_fn = (
                    mp_test_util.create_identity_with_grad_check_fn(
                        expected_dtype=dtypes.float16,
                        expected_gradient=expected_gradient))
                y = core.Lambda(identity_with_grad_check_fn)(y)
                model = models.Model(inputs=x, outputs=y)
                if get_config:
                    config = model.get_config()
                    model = model.__class__.from_config(
                        config,
                        custom_objects={
                            'MultiplyLayer': mp_test_util.MultiplyLayer
                        })
                    (layer, ) = (
                        layer for layer in model.layers
                        if isinstance(layer, mp_test_util.MultiplyLayer))

                def loss_fn(y_true, y_pred):
                    del y_true
                    return math_ops.reduce_mean(y_pred)

                model.compile(opt,
                              loss=loss_fn,
                              run_eagerly=testing_utils.should_run_eagerly())

        self.assertEqual(backend.eval(layer.v), 1)
        x = np.ones((batch_size, 1))
        y = np.ones((batch_size, 1))
        dataset = dataset_ops.Dataset.from_tensor_slices(
            (x, y)).batch(batch_size)
        model.fit(dataset)
        # The variable starts at 1 and has a gradient of 1, so it will go down
        # by 1 each step.
        self.assertEqual(backend.eval(layer.v), 0)

        model.fit(dataset)
        self.assertEqual(backend.eval(layer.v), -1)

        # There have been two steps without NaNs, so the loss scale will double
        backend.set_value(expected_gradient,
                          backend.get_value(expected_gradient * 2))
        model.fit(dataset)
        self.assertEqual(backend.eval(layer.v), -2)

        # Next test with NaN gradients.
        backend.set_value(have_nan_gradients, True)
        model.fit(dataset)
        # Variable should not be updated
        self.assertEqual(backend.eval(layer.v), -2)

        # Test with finite gradients again
        backend.set_value(have_nan_gradients, False)
        # The loss scale will be halved due to the NaNs, so the gradient will also
        # be halved
        backend.set_value(expected_gradient,
                          backend.get_value(expected_gradient / 2))
        model.fit(dataset)
        self.assertEqual(backend.eval(layer.v), -3)

    @combinations.generate(combinations.combine(mode=['graph', 'eager']))
    def test_loss_scale_optimizer_overrides_policy_v1_loss_scale(self):
        with policy.policy_scope(policy.PolicyV1('float32', loss_scale=10.)):
            opt = gradient_descent.SGD(1.)
            opt = loss_scale_optimizer.LossScaleOptimizer(opt,
                                                          dynamic=False,
                                                          initial_scale=5.)
            x = layers.Input(shape=(1, ))
            y = mp_test_util.MultiplyLayer()(x)
            model = models.Model(x, y)
            model.compile(opt, loss='mse')
            self.assertEqual(self.evaluate(model.optimizer.loss_scale), 5.)

    @combinations.generate(combinations.combine(mode=['graph', 'eager']))
    def test_policy_v1_without_loss_scale(self):
        with policy.policy_scope(
                policy.PolicyV1('mixed_float16', loss_scale=None)):
            opt = gradient_descent.SGD(1.)
            x = layers.Input(shape=(1, ))
            y = mp_test_util.MultiplyLayer()(x)
            model = models.Model(x, y)
            model.compile(opt, loss='mse')
            self.assertNotIsInstance(model.optimizer,
                                     loss_scale_optimizer.LossScaleOptimizer)

    @combinations.generate(combinations.combine(mode=['graph', 'eager']))
    def test_pass_invalid_optimizer_with_loss_scaling(self):
        with policy.policy_scope(policy.PolicyV1('float32', loss_scale=10.)):
            x = layers.Input(shape=(1, ))
            y = mp_test_util.MultiplyLayer()(x)
            model = models.Model(x, y)
            if context.executing_eagerly():
                error_msg = 'Use a `tf.keras` Optimizer instead'
            else:
                error_msg = 'optimizer" must be an instance of '
            with self.assertRaisesRegex(ValueError, error_msg):
                model.compile(optimizer_v1.SGD(1.), 'mse')

    @combinations.generate(combinations.combine(mode=['graph', 'eager']))
    def test_functional_model_loss_dtype(self):
        with policy.policy_scope('float16'):
            x = layers.Input(shape=(1, ))
            y = mp_test_util.MultiplyLayer()(x)
            model = models.Model(x, y)
            model.add_loss(math_ops.cast(y, 'float32'))
            # The loss should not be cast to the policy's dtype.
            self.assertEqual(model.losses[0].dtype, 'float32')

    @keras_parameterized.run_all_keras_modes
    @parameterized.named_parameters(
        {
            'testcase_name': 'base',
            'strategy_fn': default_strategy_fn,
        }, {
            'testcase_name': 'distribute',
            'strategy_fn': create_mirrored_strategy,
        }, {
            'testcase_name': 'base_h5',
            'strategy_fn': default_strategy_fn,
            'h5': True,
        }, {
            'testcase_name': 'distribute_h5',
            'strategy_fn': create_mirrored_strategy,
            'h5': True,
        })
    def test_save_weights_with_autocast_vars(self, strategy_fn, h5=False):
        with strategy_fn().scope():
            with policy.policy_scope('mixed_float16'):
                x = layers.Input(shape=(1, ), batch_size=2)
                layer = mp_test_util.MultiplyLayer(assert_type=dtypes.float16)
                y = layer(x)
                model = models.Model(inputs=x, outputs=y)

        model.set_weights([np.array(100.)])
        x = np.ones((2, 1))
        self.assertAllClose(backend.get_value(model(x)), x * 100.)
        suffix = '.h5' if h5 else ''
        weights_file = os.path.join(self.get_temp_dir(), 'weights' + suffix)
        model.save_weights(weights_file)

        model.set_weights([np.array(200.)])
        self.assertAllClose(backend.get_value(model(x)), x * 200.)
        model.load_weights(weights_file)
        self.assertAllClose(backend.get_value(model(x)), x * 100.)
        self.assertEqual(model.get_weights(), [np.array(100.)])

    @keras_parameterized.run_all_keras_modes
    @parameterized.named_parameters(
        {
            'testcase_name': 'base',
            'strategy_fn': default_strategy_fn,
        }, {
            'testcase_name': 'distribute',
            'strategy_fn': create_mirrored_strategy,
        }, {
            'testcase_name': 'different_var_name',
            'strategy_fn': default_strategy_fn,
            'var_name': 'w'
        }, {
            'testcase_name': 'different_var_name_distribute',
            'strategy_fn': create_mirrored_strategy,
            'var_name': 'w'
        })
    def test_save_slot_variables_with_autocast_vars(self,
                                                    strategy_fn,
                                                    var_name='v'):
        p = policy.Policy('mixed_float16')
        with strategy_fn().scope(), policy.policy_scope(p):
            x = layers.Input(shape=(2, ), batch_size=2)
            # Using a var_name other than 'v' checks that a fixed bug
            # (b/134713714) does not recur: saving a checkpoint used to crash
            # when an AutoCastVariable with a slot variable had a different
            # name than the layer attribute's name (layer.v in this case).
            layer = mp_test_util.MultiplyLayer(assert_type=dtypes.float16,
                                               var_name=var_name)
            y = layer(x)
            model = models.Model(inputs=x, outputs=y)
            opt = gradient_descent.SGD(1., 1.)
            opt = loss_scale_optimizer.LossScaleOptimizer(opt,
                                                          dynamic=False,
                                                          initial_scale=1)
            model.compile(optimizer=opt,
                          loss='mse',
                          run_eagerly=testing_utils.should_run_eagerly())

        model.fit(np.ones((2, 2)), np.zeros((2, 2)), batch_size=2)
        weights_file = os.path.join(self.get_temp_dir(), 'weights')
        model.save_weights(weights_file)
        saved_slot = backend.get_value(opt.get_slot(layer.v, 'momentum'))

        model.fit(np.ones((2, 2)), np.zeros((2, 2)), batch_size=2)
        new_slot = backend.get_value(opt.get_slot(layer.v, 'momentum'))
        self.assertNotEqual(new_slot, saved_slot)

        model.load_weights(weights_file)
        restored_slot = backend.get_value(opt.get_slot(layer.v, 'momentum'))
        self.assertEqual(restored_slot, saved_slot)

    @keras_parameterized.run_all_keras_modes
    @parameterized.named_parameters(*TESTCASES)
    def test_save_weights_with_dynamic_loss_scaling(self, strategy_fn):
        strategy = strategy_fn()
        if (isinstance(strategy, mirrored_strategy.MirroredStrategy)
                and not context.executing_eagerly()):
            # TODO(b/121381184): Enable running the test in this case.
            return

        # Create and run model.
        with strategy.scope():
            x = layers.Input(shape=(2, ), batch_size=2, dtype=dtypes.float32)
            y = mp_test_util.MultiplyLayer(assert_type=dtypes.float32)(x)
            model = models.Model(inputs=x, outputs=y)

            opt = gradient_descent.SGD(1.)
            opt = loss_scale_optimizer.LossScaleOptimizer(
                opt, initial_scale=1., dynamic_growth_steps=2.)
            model.compile(optimizer=opt,
                          loss='mse',
                          run_eagerly=testing_utils.should_run_eagerly())
        # Run for 3 steps (6 examples with a batch size of 2)
        model.fit(np.zeros((6, 2)), np.zeros((6, 2)), batch_size=2)
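        # The scale starts at 1 and doubles after every 2 finite steps:
        # steps 1-2 raise it to 2, and step 3 leaves dynamic_counter at 1.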
        self.assertEqual(backend.get_value(opt.loss_scale), 2)
        self.assertEqual(backend.get_value(opt.dynamic_counter), 1)

        # Save model weights.
        save_prefix = os.path.join(self.get_temp_dir(), 'ckpt')
        model.save_weights(save_prefix)

        # Run model again for 1 step (2 examples with a batch size of 2)
        model.fit(np.zeros((2, 2)), np.zeros((2, 2)), batch_size=2)
        self.assertEqual(backend.get_value(opt.loss_scale), 4)
        self.assertEqual(backend.get_value(opt.dynamic_counter), 0)

        # Load model weights and ensure loss scale weights are restored.
        model.load_weights(save_prefix)
        self.assertEqual(backend.get_value(opt.loss_scale), 2)
        self.assertEqual(backend.get_value(opt.dynamic_counter), 1)

    @keras_parameterized.run_all_keras_modes
    def test_restore_old_loss_scale_checkpoint(self):
        # Ensure a checkpoint from TF 2.2 can be loaded. The checkpoint format
        # of LossScaleOptimizer changed, but old checkpoints can still be loaded
        opt = gradient_descent.SGD(0.1, momentum=0.1)
        opt = loss_scale_optimizer.LossScaleOptimizer(opt)
        model = sequential.Sequential([core.Dense(2, )])

        # The checkpoint and expected values were obtained from the program in
        # testdata/BUILD.
        ckpt_dir = os.path.join(flags.FLAGS['test_srcdir'].value,
                                'org_tensorflow/tensorflow/python/keras',
                                'mixed_precision/testdata/lso_ckpt_tf2.2')
        # ckpt_dir = test.test_src_dir_path(
        #     'python/keras/mixed_precision/testdata/lso_ckpt_tf2.2')
        model.load_weights(os.path.join(ckpt_dir, 'ckpt'))
        model.compile(opt,
                      'mse',
                      run_eagerly=testing_utils.should_run_eagerly())
        model(np.zeros((2, 2)))  # Create model weights
        opt._create_all_weights(model.weights)
        expected_kernel = np.array([[9.229685, 10.901115],
                                    [10.370763, 9.757362]])
        expected_slot = np.array([[10.049943, 9.917691],
                                  [10.049943, 9.917691]])
        self.assertAllClose(self.evaluate(model.weights[0]), expected_kernel)
        self.assertAllClose(
            self.evaluate(opt.get_slot(model.weights[0], 'momentum')),
            expected_slot)
        self.assertEqual(self.evaluate(opt.loss_scale), 32768)
        self.assertEqual(self.evaluate(opt.dynamic_counter), 1)

        # Check restoring works even after the model is compiled and the weights
        # have been created.
        model.fit(np.random.normal(size=(2, 2)), np.random.normal(size=(2, 2)))
        self.assertNotAllClose(self.evaluate(model.weights[0]),
                               expected_kernel)
        self.assertNotAllClose(
            self.evaluate(opt.get_slot(model.weights[0], 'momentum')),
            expected_slot)
        model.load_weights(os.path.join(ckpt_dir, 'ckpt'))
        self.assertAllClose(self.evaluate(model.weights[0]), expected_kernel)
        self.assertAllClose(
            self.evaluate(opt.get_slot(model.weights[0], 'momentum')),
            expected_slot)
        self.assertEqual(self.evaluate(opt.loss_scale), 32768)
        self.assertEqual(self.evaluate(opt.dynamic_counter), 1)

    def test_restore_old_saved_model(self):
        saved_model_dir = os.path.join(
            flags.FLAGS['test_srcdir'].value,
            'org_tensorflow/tensorflow/python/keras',
            'mixed_precision/testdata/lso_savedmodel_tf2.2')
        # saved_model_dir = test.test_src_dir_path(
        #     'python/keras/mixed_precision/testdata/'
        #     'lso_savedmodel_tf2.2')
        model = save.load_model(saved_model_dir)
        expected_kernel = np.array([[9.229685, 10.901115],
                                    [10.370763, 9.757362]])
        self.assertAllClose(backend.eval(model.weights[0]), expected_kernel)
        self.assertEqual(type(model.optimizer),
                         loss_scale_optimizer.LossScaleOptimizer)

    @keras_parameterized.run_all_keras_modes
    @parameterized.named_parameters(
        {
            'testcase_name': 'base',
            'strategy_fn': default_strategy_fn,
        }, {
            'testcase_name': 'distribute',
            'strategy_fn': create_mirrored_strategy,
        }, {
            'testcase_name': 'use_v1_lso',
            'strategy_fn': create_mirrored_strategy,
            'use_v1_loss_scale_optimizer': True
        }, {
            'testcase_name': 'base_h5',
            'strategy_fn': default_strategy_fn,
            'h5': True,
        }, {
            'testcase_name': 'distribute_h5',
            'strategy_fn': create_mirrored_strategy,
            'h5': True,
        })
    def test_save_model_with_dynamic_loss_scaling(
            self, strategy_fn, h5=False, use_v1_loss_scale_optimizer=False):
        # TODO(reedwm): Support and test saving model with a mixed_[b]float16 policy
        # as well.
        strategy = strategy_fn()
        if (isinstance(strategy, mirrored_strategy.MirroredStrategy)
                and not context.executing_eagerly()):
            # TODO(b/121381184): Enable running the test in this case.
            return

        # Create and run model.
        with strategy.scope():
            x = layers.Input(shape=(2, ), batch_size=2, dtype=dtypes.float32)
            y = mp_test_util.MultiplyLayer()(x)
            model = models.Model(inputs=x, outputs=y)

            opt = gradient_descent.SGD(1.)
            if use_v1_loss_scale_optimizer:
                loss_scale = loss_scale_module.DynamicLossScale(
                    initial_loss_scale=1., increment_period=2.)
                opt = loss_scale_optimizer.LossScaleOptimizerV1(
                    opt, loss_scale)
            else:
                opt = loss_scale_optimizer.LossScaleOptimizer(
                    opt, initial_scale=1., dynamic_growth_steps=2.)
            model.compile(optimizer=opt,
                          loss='mse',
                          run_eagerly=testing_utils.should_run_eagerly())
        # Run for 3 steps (6 examples with a batch size of 2)
        model.fit(np.ones((6, 2)), np.zeros((6, 2)), batch_size=2)
        self.assertEqual(backend.get_value(opt.loss_scale), 2)
        self.assertEqual(backend.get_value(opt.dynamic_counter), 1)
        (weight, ) = model.trainable_weights
        orig_weight = backend.get_value(weight)

        # Save model weights.
        save_path = os.path.join(self.get_temp_dir(), 'model')
        model.save(save_path, save_format='h5' if h5 else 'tf')

        # Run model again for 1 step (2 examples with a batch size of 2)
        model.fit(np.ones((2, 2)), np.zeros((2, 2)), batch_size=2)
        new_weight = backend.get_value(weight)
        self.assertNotEqual(new_weight, orig_weight)
        self.assertEqual(backend.get_value(opt.loss_scale), 4)
        self.assertEqual(backend.get_value(opt.dynamic_counter), 0)

        # Load model weights and ensure loss scale weights are restored.
        model = save.load_model(
            save_path,
            custom_objects={'MultiplyLayer': mp_test_util.MultiplyLayer})
        (weight, ) = model.trainable_weights
        loaded_weight = backend.get_value(weight)
        self.assertEqual(loaded_weight, orig_weight)
        # Currently the loss scale isn't always saved when the model is saved
        # with Model.save(). So we assert that the loss scale either has the
        # value it had when saved, or the value it was initialized with.
        # TODO(reedwm): Always save/restore the loss scale with Model.save().
        self.assertIn(backend.get_value(model.optimizer.loss_scale), (1, 2))
        self.assertIn(backend.get_value(model.optimizer.dynamic_counter),
                      (0, 1))

        # Test optimizer attributes and type
        self.assertEqual(model.optimizer.initial_scale, 1.)
        self.assertEqual(model.optimizer.dynamic_growth_steps, 2.)
        self.assertEqual(type(model.optimizer),
                         loss_scale_optimizer.LossScaleOptimizer)
Example #17
class TestLayerCallTracing(test.TestCase, parameterized.TestCase):
    def test_functions_have_same_trace(self):
        class Layer(keras.engine.base_layer.Layer):
            def call(self, inputs):
                return inputs

            def call2(self, inputs):
                return inputs * 2

        layer = Layer()
        call_collection = keras_save.LayerCallCollection(layer)
        fn = call_collection.add_function(layer.call, 'call')
        fn2 = call_collection.add_function(layer.call2, 'call2')

        fn(np.ones((2, 3)))
        fn(np.ones((4, 5)))

        self.assertLen(fn._list_all_concrete_functions_for_serialization(), 2)
        self.assertLen(fn2._list_all_concrete_functions_for_serialization(), 2)

        # Check that the shapes are correct
        self.assertEqual(
            {(2, 3), (4, 5)},
            set(
                tuple(c.structured_input_signature[0][0].shape.as_list())
                for c in fn2._list_all_concrete_functions_for_serialization()))

    def test_training_arg_replacement(self):
        def assert_num_traces(layer_cls, training_keyword):
            layer = layer_cls()
            call_collection = keras_save.LayerCallCollection(layer)
            fn = call_collection.add_function(layer.call, 'call')

            fn(np.ones((2, 3)), training=True)
            self.assertLen(fn._list_all_concrete_functions_for_serialization(),
                           2)

            fn(np.ones((2, 4)), training=False)
            self.assertLen(fn._list_all_concrete_functions_for_serialization(),
                           4)

            if training_keyword:
                fn(np.ones((2, 5)), True)
                self.assertLen(
                    fn._list_all_concrete_functions_for_serialization(), 6)
                fn(np.ones((2, 6)))
                self.assertLen(
                    fn._list_all_concrete_functions_for_serialization(), 8)

        class LayerWithTrainingKeyword(keras.engine.base_layer.Layer):
            def call(self, inputs, training=False):
                return inputs * training

        assert_num_traces(LayerWithTrainingKeyword, training_keyword=True)

        class LayerWithKwargs(keras.engine.base_layer.Layer):
            def call(self, inputs, **kwargs):
                return inputs * kwargs['training']

        assert_num_traces(LayerWithKwargs, training_keyword=False)

        class LayerWithChildLayer(keras.engine.base_layer.Layer):
            def __init__(self):
                self.child = LayerWithKwargs()
                super(LayerWithChildLayer, self).__init__()

            def call(self, inputs):
                return self.child(inputs)

        assert_num_traces(LayerWithChildLayer, training_keyword=False)

    @combinations.generate(combinations.combine(mode=['graph', 'eager']))
    def test_maintains_losses(self):
        layer = LayerWithLoss()
        layer(np.ones((2, 3)))
        previous_losses = layer.losses[:]

        call_collection = keras_save.LayerCallCollection(layer)
        fn = call_collection.add_function(layer.call, 'call')
        fn(np.ones((2, 3)))

        self.assertAllEqual(previous_losses, layer.losses)
Example #18
class MappingTests(keras_parameterized.TestCase):
    @combinations.generate(combinations.combine(mode=["graph", "eager"]))
    def testTracking(self):
        with self.test_session():
            model = HasMapping()
            output = model(array_ops.ones([32, 2]))
            self.assertAllEqual([32, 7], output.shape.as_list())
            self.assertEqual(5, len(model.layers))
            self.assertEqual(len(model.layers), len(model.layer_dict.layers))
            self.assertEqual(1, len(model._checkpoint_dependencies))
            self.assertIs(model.layer_dict,
                          model._checkpoint_dependencies[0].ref)
            self.evaluate([v.initializer for v in model.variables])
            test_var = model.layer_dict["output"].kernel
            self.evaluate(test_var.assign(array_ops.ones([6, 7])))
            save_path = os.path.join(self.get_temp_dir(), "ckpt")
            model.save_weights(save_path)
            self.evaluate(test_var.assign(array_ops.zeros([6, 7])))
            model.load_weights(save_path)
            self.assertAllEqual(numpy.ones([6, 7]), self.evaluate(test_var))

    def testLayerCollectionWithExternalMutation(self):
        d = {}
        root = module.Module()
        root.wrapper = d
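        # Assigning a plain dict to a Module attribute wraps it in a trackable
        # dict wrapper, so later mutations of `d` remain visible through
        # `root.wrapper`.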
        self.assertEqual([], root.wrapper.layers)
        self.assertEqual([], root.wrapper.trainable_weights)
        layer1 = core.Dense(1)
        layer2 = core.Dense(1)
        d["a"] = layer1
        d["b"] = layer2
        self.assertEqual([layer1, layer2], root.wrapper.layers)
        # The layers still have not created any variables.
        self.assertEqual([], root.wrapper.trainable_weights)

    def testDictWrapperBadKeys(self):
        a = module.Module()
        a.d = {}
        a.d[1] = data_structures.wrap_or_unwrap([])
        model = training.Model()
        model.sub = a
        save_path = os.path.join(self.get_temp_dir(), "ckpt")
        with self.assertRaisesRegex(ValueError, "non-string key"):
            model.save_weights(save_path)

    def testDictWrapperNoDependency(self):
        a = module.Module()
        a.d = data_structures.NoDependency({})
        a.d[1] = [3]
        self.assertEqual([a], util.list_objects(a))
        model = training.Model()
        model.sub = a
        save_path = os.path.join(self.get_temp_dir(), "ckpt")
        model.save_weights(save_path)
        model.load_weights(save_path)

    def testNonStringKeyNotTrackableValue(self):
        a = module.Module()
        a.d = {}
        a.d["a"] = [3]
        a.d[1] = data_structures.NoDependency([3])
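        # The NoDependency wrapper keeps a.d[1] out of the tracked object
        # graph, so only a, a.d, and a.d["a"] are listed.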
        self.assertEqual([a, a.d, a.d["a"]], util.list_objects(a))
        model = training.Model()
        model.sub = a
        save_path = os.path.join(self.get_temp_dir(), "ckpt")
        model.save_weights(save_path)
        model.load_weights(save_path)

    def testNonAppendNotTrackable(self):
        # Non-append mutations (deleting or overwriting values) are OK when the
        # values aren't tracked.
        a = module.Module()
        a.d = {}
        a.d["a"] = [3]
        a.d[1] = 3
        a.d[1] = 2
        self.assertEqual(2, a.d[1])
        del a.d[1]
        a.d[2] = data_structures.NoDependency(module.Module())
        second = module.Module()
        a.d[2] = data_structures.NoDependency(second)
        self.assertIs(second, a.d[2])
        self.assertEqual([a, a.d, a.d["a"]], util.list_objects(a))
        model = training.Model()
        model.sub = a
        save_path = os.path.join(self.get_temp_dir(), "ckpt")
        model.save_weights(save_path)
        model.load_weights(save_path)

    def testPopNoSave(self):
        model = training.Model()
        model.d = {}
        model.d["a"] = []
        model.d.pop("a")
        save_path = os.path.join(self.get_temp_dir(), "ckpt")
        with self.assertRaisesRegex(ValueError, "Unable to save"):
            model.save_weights(save_path)

    def testExternalModificationNoSave(self):
        model = training.Model()
        external_reference = {}
        model.d = external_reference
        external_reference["a"] = []
        save_path = os.path.join(self.get_temp_dir(), "ckpt")
        with self.assertRaisesRegex(ValueError,
                                    "modified outside the wrapper"):
            model.save_weights(save_path)

    def testOverwriteCanStillSave(self):
        model = training.Model()
        model.d = {}
        model.d["a"] = {}
        model.d["a"] = {}
        save_path = os.path.join(self.get_temp_dir(), "ckpt")
        model.save_weights(save_path)

    def testIter(self):
        model = training.Model()
        model.d = {1: 3}
        model.d[1] = 3
        self.assertEqual([1], list(model.d))
        new_dict = {}
        # This update() is super tricky. If the dict wrapper subclasses dict,
        # CPython will access its storage directly instead of calling any
        # methods/properties on the object. So the options are either not to
        # subclass dict (in which case update() will go through the normal
        # iteration methods, but the object won't pass isinstance checks) or to
        # subclass dict and keep that storage updated (rather than shadowing
        # all of its methods, as ListWrapper does).
        new_dict.update(model.d)
        self.assertEqual({1: 3}, new_dict)
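
The subtlety described in the comment above is reproducible in plain Python, with no TensorFlow involved (a minimal sketch): dict.update() takes a fast path for dict subclasses and reads their underlying storage directly, ignoring overridden methods.

class ShadowingDict(dict):
    """Subclasses dict but keeps its real data outside the dict storage."""
    def keys(self):
        return ['shadow']
    def __getitem__(self, key):
        return 'shadow-value'

d = ShadowingDict(a=1)   # 'a' lives in the underlying dict storage
print(list(d.keys()))    # ['shadow']: the override works when called directly
plain = {}
plain.update(d)          # the fast path copies the dict storage instead
print(plain)             # {'a': 1}, not {'shadow': 'shadow-value'}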
Example #19
class TestSaveModel(test.TestCase, parameterized.TestCase):
    def setUp(self):
        super(TestSaveModel, self).setUp()
        self.model = testing_utils.get_small_sequential_mlp(1, 2, 3)
        self.subclassed_model = testing_utils.get_small_subclass_mlp(1, 2)

    def assert_h5_format(self, path):
        if h5py is not None:
            self.assertTrue(
                h5py.is_hdf5(path),
                'Model saved at path {} is not a valid hdf5 file.'.format(
                    path))

    def assert_saved_model(self, path):
        loader_impl.parse_saved_model(path)

    @testing_utils.run_v2_only
    def test_save_format_defaults(self):
        path = os.path.join(self.get_temp_dir(), 'model_path')
        save.save_model(self.model, path)
        self.assert_saved_model(path)

    @testing_utils.run_v2_only
    def test_save_format_defaults_pathlib(self):
        if sys.version_info < (3, 6):
            self.skipTest(
                'pathlib is only available for python version >= 3.6')
        path = pathlib.Path(self.get_temp_dir()) / 'model_path'
        save.save_model(self.model, path)
        self.assert_saved_model(path)

    @testing_utils.run_v2_only
    def test_save_hdf5(self):
        path = os.path.join(self.get_temp_dir(), 'model')
        save.save_model(self.model, path, save_format='h5')
        self.assert_h5_format(path)
        with self.assertRaisesRegex(
                NotImplementedError,
                'requires the model to be a Functional model or a Sequential model.'
        ):
            save.save_model(self.subclassed_model, path, save_format='h5')

    @testing_utils.run_v2_only
    def test_save_load_hdf5_pathlib(self):
        if sys.version_info < (3, 6):
            self.skipTest(
                'pathlib is only available for python version >= 3.6')
        path = pathlib.Path(self.get_temp_dir()) / 'model'
        save.save_model(self.model, path, save_format='h5')
        save.load_model(path)

    @testing_utils.run_v2_only
    def test_save_tf(self):
        path = os.path.join(self.get_temp_dir(), 'model')
        save.save_model(self.model, path, save_format='tf')
        self.assert_saved_model(path)
        with self.assertRaisesRegex(ValueError,
                                    'input shapes have not been set'):
            save.save_model(self.subclassed_model, path, save_format='tf')
        self.subclassed_model.predict(np.random.random((3, 5)))
        save.save_model(self.subclassed_model, path, save_format='tf')
        self.assert_saved_model(path)

    @testing_utils.run_v2_only
    def test_save_load_tf_string(self):
        path = os.path.join(self.get_temp_dir(), 'model')
        save.save_model(self.model, path, save_format='tf')
        save.load_model(path)

    @testing_utils.run_v2_only
    def test_save_load_tf_pathlib(self):
        if sys.version_info < (3, 6):
            self.skipTest(
                'pathlib is only available for python version >= 3.6')
        path = pathlib.Path(self.get_temp_dir()) / 'model'
        save.save_model(self.model, path, save_format='tf')
        save.load_model(path)

    @testing_utils.run_v2_only
    def test_save_load_weights_tf_pathlib(self):
        if sys.version_info < (3, 6):
            self.skipTest(
                'pathlib is only available for python version >= 3.6')
        path = pathlib.Path(self.get_temp_dir()) / 'model'
        self.model.save_weights(path, save_format='tf')
        self.model.load_weights(path)

    @testing_utils.run_v2_only
    def test_save_load_weights_hdf5_pathlib(self):
        if sys.version_info < (3, 6):
            self.skipTest(
                'pathlib is only available for python version >= 3.6')
        path = pathlib.Path(self.get_temp_dir()) / 'model'
        self.model.save_weights(path, save_format='h5')
        self.model.load_weights(path)

    @combinations.generate(combinations.combine(mode=['graph', 'eager']))
    def test_saving_with_dense_features(self):
        cols = [
            feature_column_lib.numeric_column('a'),
            feature_column_lib.indicator_column(
                feature_column_lib.categorical_column_with_vocabulary_list(
                    'b', ['one', 'two']))
        ]
        input_layers = {
            'a': keras.layers.Input(shape=(1, ), name='a'),
            'b': keras.layers.Input(shape=(1, ), name='b', dtype='string')
        }

        fc_layer = dense_features.DenseFeatures(cols)(input_layers)
        output = keras.layers.Dense(10)(fc_layer)

        model = keras.models.Model(input_layers, output)

        model.compile(loss=keras.losses.MSE,
                      optimizer='rmsprop',
                      metrics=[keras.metrics.categorical_accuracy])

        config = model.to_json()
        loaded_model = model_config.model_from_json(config)

        inputs_a = np.arange(10).reshape(10, 1)
        inputs_b = np.arange(10).reshape(10, 1).astype('str')

        with self.cached_session():
            # Initialize tables for V1 lookup.
            if not context.executing_eagerly():
                self.evaluate(lookup_ops.tables_initializer())

            self.assertLen(
                loaded_model.predict({
                    'a': inputs_a,
                    'b': inputs_b
                }), 10)

    @combinations.generate(combinations.combine(mode=['graph', 'eager']))
    def test_saving_with_sequence_features(self):
        cols = [
            feature_column_lib.sequence_numeric_column('a'),
            feature_column_lib.indicator_column(
                feature_column_lib.
                sequence_categorical_column_with_vocabulary_list(
                    'b', ['one', 'two']))
        ]
        input_layers = {
            'a':
            keras.layers.Input(shape=(None, 1), sparse=True, name='a'),
            'b':
            keras.layers.Input(shape=(None, 1),
                               sparse=True,
                               name='b',
                               dtype='string')
        }

        fc_layer, _ = ksfc.SequenceFeatures(cols)(input_layers)
        # TODO(tibell): Figure out the right dtype and apply masking.
        # sequence_length_mask = array_ops.sequence_mask(sequence_length)
        # x = keras.layers.GRU(32)(fc_layer, mask=sequence_length_mask)
        x = keras.layers.GRU(32)(fc_layer)
        output = keras.layers.Dense(10)(x)

        model = keras.models.Model(input_layers, output)

        model.compile(loss=keras.losses.MSE,
                      optimizer='rmsprop',
                      metrics=[keras.metrics.categorical_accuracy])

        config = model.to_json()
        loaded_model = model_config.model_from_json(config)

        batch_size = 10
        timesteps = 1

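        # Put value i at sparse position (batch=i, timestep=0, feature=0).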
        values_a = np.arange(10, dtype=np.float32)
        indices_a = np.zeros((10, 3), dtype=np.int64)
        indices_a[:, 0] = np.arange(10)
        inputs_a = sparse_tensor.SparseTensor(indices_a, values_a,
                                              (batch_size, timesteps, 1))

        values_b = np.zeros(10, dtype=str)  # np.str alias was removed from NumPy
        indices_b = np.zeros((10, 3), dtype=np.int64)
        indices_b[:, 0] = np.arange(10)
        inputs_b = sparse_tensor.SparseTensor(indices_b, values_b,
                                              (batch_size, timesteps, 1))

        with self.cached_session():
            # Initialize tables for V1 lookup.
            if not context.executing_eagerly():
                self.evaluate(lookup_ops.tables_initializer())

            self.assertLen(
                loaded_model.predict({
                    'a': inputs_a,
                    'b': inputs_b
                }, steps=1), batch_size)

    @combinations.generate(combinations.combine(mode=['graph', 'eager']))
    def test_saving_h5_for_rnn_layers(self):
        # See https://github.com/tensorflow/tensorflow/issues/35731 for details.
        inputs = keras.Input([10, 91], name='train_input')
        rnn_layers = [
            keras.layers.LSTMCell(size,
                                  recurrent_dropout=0,
                                  name='rnn_cell%d' % i)
            for i, size in enumerate([512, 512])
        ]
        rnn_output = keras.layers.RNN(rnn_layers,
                                      return_sequences=True,
                                      name='rnn_layer')(inputs)
        pred_feat = keras.layers.Dense(91,
                                       name='prediction_features')(rnn_output)
        pred = keras.layers.Softmax()(pred_feat)
        model = keras.Model(inputs=[inputs], outputs=[pred, pred_feat])
        path = os.path.join(self.get_temp_dir(), 'model_path.h5')
        model.save(path)

        # Make sure the variable names are unique.
        self.assertNotEqual(rnn_layers[0].kernel.name,
                            rnn_layers[1].kernel.name)
        self.assertIn('rnn_cell1', rnn_layers[1].kernel.name)

    @combinations.generate(combinations.combine(mode=['graph', 'eager']))
    def test_saving_optimizer_weights(self):
        class MyModel(keras.Model):
            def __init__(self):
                super(MyModel, self).__init__()
                self.layer = keras.layers.Dense(1)

            def call(self, x):
                return self.layer(x)

        path = os.path.join(self.get_temp_dir(), 'weights_path')
        x, y = np.ones((10, 10)), np.ones((10, 1))

        model = MyModel()
        model.compile('rmsprop', loss='bce')
        model.train_on_batch(x, y)
        model.reset_metrics()
        model.save_weights(path, save_format='tf')

        batch_loss = model.train_on_batch(x, y)

        new_model = MyModel()
        new_model.compile('rmsprop', loss='bce')
        new_model.train_on_batch(x, y)
        new_model.reset_metrics()

        new_model.load_weights(path)
        new_batch_loss = new_model.train_on_batch(x, y)

        self.assertAllClose(batch_loss, new_batch_loss)
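
The losses match because the TF-format checkpoint stores the optimizer's slot variables (RMSprop's per-variable moving averages) alongside the model weights. A sketch of inspecting them, assuming RMSprop's slot name 'rms':

for var in model.trainable_variables:
    rms = model.optimizer.get_slot(var, 'rms')  # per-variable accumulator
    print(var.name, rms.shape)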

    @combinations.generate(combinations.combine(mode=['graph', 'eager']))
    def test_saving_model_with_custom_object(self):
        with generic_utils.custom_object_scope():

            @generic_utils.register_keras_serializable()
            class CustomLoss(losses.MeanSquaredError):
                pass

            model = sequential.Sequential(
                [core.Dense(units=1, input_shape=(1, ))])
            model.compile(optimizer='sgd', loss=CustomLoss())
            model.fit(np.zeros([10, 1]), np.zeros([10, 1]))

            temp_dir = self.get_temp_dir()
            filepath = os.path.join(temp_dir, 'saving')
            model.save(filepath)

            # Make sure the model can be correctly loaded back.
            _ = save.load_model(filepath, compile=True)
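
The decorator registers the class under the default 'Custom' package, so its serialized name round-trips without passing custom_objects explicitly. A one-line sketch of checking the registered name:

print(generic_utils.get_registered_name(CustomLoss))  # 'Custom>CustomLoss'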
Example #20
class TupleTests(keras_parameterized.TestCase):
    @combinations.generate(combinations.combine(mode=["graph", "eager"]))
    def testTracking(self):
        with self.test_session():
            model = HasTuple()
            output = model(array_ops.ones([32, 2]))
            self.assertAllEqual([32, 5], output.shape.as_list())
            self.assertLen(model.layers, 4)
            self.assertLen(model.layer_list.layers, 3)
            self.assertEqual(
                len(model.layers),
                len(
                    tuple(model.layer_list.layers) +
                    model.layers_with_updates))
            self.assertEqual(3, model.layer_list.layers[0].units)
            self.assertEqual(4, model.layer_list.layers[1].units)
            self.assertEqual(5, model.layer_list.layers[2].units)
            self.assertLen(model._checkpoint_dependencies, 2)
            self.assertIs(model.layer_list,
                          model._checkpoint_dependencies[0].ref)
            self.assertIs(model.layers_with_updates,
                          model._checkpoint_dependencies[1].ref)
            self.assertLen(
                model._checkpoint_dependencies[0].ref._checkpoint_dependencies,
                3)
            self.evaluate([v.initializer for v in model.variables])
            self.evaluate(model.variables[0].assign([[1., 2., 3.],
                                                     [4., 5., 6.]]))
            save_path = os.path.join(self.get_temp_dir(), "ckpt")
            model.save_weights(save_path)
            self.evaluate(model.variables[0].assign(array_ops.zeros([2, 3])))
            model.load_weights(save_path)
            self.assertAllEqual([[1., 2., 3.], [4., 5., 6.]],
                                self.evaluate(model.variables[0]))
            v = variables.Variable(1.)
            model.var_list = (v, )
            self.assertIn(id(v), [id(obj) for obj in model.variables])
            self.assertIn(id(v),
                          [id(obj) for obj in model.trainable_variables])
            self.assertNotIn(
                id(v), [id(obj) for obj in model.non_trainable_variables])
            self.assertIn(id(model.layer_list[0].trainable_weights[0]),
                          [id(obj) for obj in model.trainable_weights])

    @parameterized.named_parameters(
        ("Module", module.Module),
        ("Model", training.Model),
    )
    def testSubModelTracking(self, module_subclass):
        model = module_subclass()
        model.v = variables.Variable(1.)
        self.assertIn(model.v, model.trainable_variables)
        model2 = module_subclass()
        model2.m = (model, )
        self.assertIn(model.v, model2.trainable_variables)

    def testSubSequentialTracking(self):
        class _Subclassed(training.Model):
            def __init__(self, wrapped):
                super(_Subclassed, self).__init__()
                self._wrapped = wrapped

            def call(self, x):
                return self._wrapped(x)

        model = sequential.Sequential()
        layer = core.Dense(1)
        model.add(layer)
        model2 = _Subclassed(model)
        model2(array_ops.ones([1, 2]))
        model2.m = (model, )
        self.assertIn(layer.kernel, model2.trainable_weights)

    def testUpdatesForwarded(self):
        with ops.Graph().as_default():
            model = HasTuple()
            model_input = array_ops.ones([32, 2])
            model(model_input)
            self.assertNotEmpty(model.layers_with_updates[0].updates)
            self.assertEqual(set(model.layers_with_updates[0].updates),
                             set(model.updates))

        model = HasTuple()
        model_input = array_ops.ones([32, 2])
        model(model_input)
        self.assertEmpty(model.updates)

    @combinations.generate(combinations.combine(mode=["graph", "eager"]))
    def testLossesForwarded(self):
        model = HasTuple()
        model_input = array_ops.ones([32, 2])
        model(model_input)
        self.assertLen(model.losses, 1)

    def testModelContainersCompareEqual(self):
        class HasEqualContainers(training.Model):
            def __init__(self):
                super(HasEqualContainers, self).__init__()
                self.l1 = ()
                self.l2 = ()

        model = HasEqualContainers()
        first_layer = HasEqualContainers()
        model.l1 = (first_layer, )
        second_layer = HasEqualContainers()
        model.l2 = (second_layer, )
        self.assertEqual((first_layer, ), model.l1)
        d = {model.l1: 1, model.l2: 2}
        self.assertEqual(1, d[model.l1])
        self.assertEqual(1, d[(first_layer, )])
        self.assertEqual(2, d[model.l2])
        self.assertEqual(2, d[(second_layer, )])
        self.assertEqual([first_layer, second_layer], model.layers)

    @combinations.generate(combinations.combine(mode=["graph", "eager"]))
    def testTensorConversion(self):
        class TupleToTensor(training.Model):
            def __init__(self):
                super(TupleToTensor, self).__init__()
                self.l = (1., 2., 3.)

        self.assertAllEqual(
            (1., 2., 3.),
            self.evaluate(constant_op.constant(TupleToTensor().l)))

        self.assertAllEqual(
            (1., 2., 3.),
            self.evaluate(gen_array_ops.Pack(values=TupleToTensor().l)))
Example #21
class CheckpointCompatibilityTests(keras_parameterized.TestCase):
    def _initialized_model(self):
        input_value = constant_op.constant([[3.]])
        model = MyModel()
        optimizer = adam.Adam(0.001)
        root_trackable = trackable_utils.Checkpoint(optimizer=optimizer,
                                                    model=model)
        with backprop.GradientTape() as tape:
            loss = model(input_value)
        variables = model.trainable_variables
        gradients = tape.gradient(loss, variables)
        train_op = optimizer.apply_gradients(zip(gradients, variables))
        self.evaluate(trackable_utils.gather_initializers(root_trackable))
        self.evaluate(train_op)
        # A regular variable, a slot variable, and a non-slot Optimizer variable
        # with known values to check when loading.
        self.evaluate(model._named_dense.bias.assign([1.]))
        self.evaluate(
            optimizer.get_slot(var=model._named_dense.bias,
                               slot_name="m").assign([2.]))
        self.evaluate(optimizer.beta_1.assign(3.))
        return root_trackable

    def _set_sentinels(self, root_trackable):
        self.evaluate(root_trackable.model._named_dense.bias.assign([101.]))
        self.evaluate(
            root_trackable.optimizer.get_slot(
                var=root_trackable.model._named_dense.bias,
                slot_name="m").assign([102.]))
        self.evaluate(root_trackable.optimizer.beta_1.assign(103.))

    def _check_sentinels(self, root_trackable):
        self.assertAllEqual([1.],
                            self.evaluate(
                                root_trackable.model._named_dense.bias))
        self.assertAllEqual([2.],
                            self.evaluate(
                                root_trackable.optimizer.get_slot(
                                    var=root_trackable.model._named_dense.bias,
                                    slot_name="m")))
        self.assertAllEqual(3., self.evaluate(root_trackable.optimizer.beta_1))

    def _write_name_based_checkpoint(self):
        checkpoint_directory = self.get_temp_dir()
        checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
        with context.graph_mode():
            save_graph = ops.Graph()
            with save_graph.as_default(), self.session(
                    graph=save_graph) as session:
                root = self._initialized_model()
                name_saver = saver_lib.Saver()
                return name_saver.save(sess=session,
                                       save_path=checkpoint_prefix,
                                       global_step=root.optimizer.iterations)

    @combinations.generate(combinations.combine(mode=["graph", "eager"]))
    def testLoadFromNameBasedSaver(self):
        """Save a name-based checkpoint, load it using the object-based API."""
        with testing_utils.device(should_use_gpu=True):
            with self.test_session():
                save_path = self._write_name_based_checkpoint()
                root = self._initialized_model()
                self._set_sentinels(root)
                with self.assertRaises(AssertionError):
                    self._check_sentinels(root)
                object_saver = trackable_utils.TrackableSaver(
                    graph_view.ObjectGraphView(root))
                self._set_sentinels(root)
                status = object_saver.restore(save_path)
                if context.executing_eagerly():
                    self._check_sentinels(root)
                    status.assert_consumed()
                    status.assert_existing_objects_matched()
                    status.assert_nontrivial_match()
                else:
                    # When graph building, we haven't read any keys, so we don't know
                    # whether the restore will be complete.
                    with self.assertRaisesRegex(AssertionError,
                                                "not restored"):
                        status.assert_consumed()
                    with self.assertRaisesRegex(AssertionError,
                                                "not restored"):
                        status.assert_existing_objects_matched()
                    with self.assertRaisesRegex(AssertionError,
                                                "not restored"):
                        status.assert_nontrivial_match()
                status.run_restore_ops()
                self._check_sentinels(root)
                self._set_sentinels(root)
                status = object_saver.restore(save_path)
                status.initialize_or_restore()
                status.assert_nontrivial_match()
                self._check_sentinels(root)
                # Check that there is no error when keys are missing from the name-based
                # checkpoint.
                root.not_in_name_checkpoint = variables_lib.Variable([1.])
                status = object_saver.restore(save_path)
                with self.assertRaises(AssertionError):
                    status.assert_existing_objects_matched()

    def testSaveGraphLoadEager(self):
        checkpoint_directory = self.get_temp_dir()
        checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
        with context.graph_mode():
            save_graph = ops.Graph()
            with save_graph.as_default(), self.session(graph=save_graph):
                root = self._initialized_model()
                save_path = root.save(file_prefix=checkpoint_prefix)
        with context.eager_mode():
            root = self._initialized_model()
            self._set_sentinels(root)
            root.restore(save_path).assert_consumed()
            self._check_sentinels(root)

    def testSaveEagerLoadGraph(self):
        checkpoint_directory = self.get_temp_dir()
        checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
        with context.eager_mode():
            root = self._initialized_model()
            save_path = root.save(file_prefix=checkpoint_prefix)
        with context.graph_mode():
            save_graph = ops.Graph()
            with save_graph.as_default(), self.session(graph=save_graph):
                root = self._initialized_model()
                self._set_sentinels(root)
                root.restore(save_path).assert_consumed().run_restore_ops()
                self._check_sentinels(root)

    def testIgnoreSaveCounter(self):
        checkpoint_directory = self.get_temp_dir()
        checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
        with self.cached_session() as session:
            # Create and save a model using Saver() before using a Checkpoint. This
            # generates a snapshot without the Checkpoint's `save_counter`.
            model = sequential.Sequential()
            model.add(core.Flatten(input_shape=(1, )))
            model.add(core.Dense(1))
            name_saver = saver_lib.Saver(model.trainable_variables)
            save_path = name_saver.save(sess=session,
                                        save_path=checkpoint_prefix,
                                        global_step=1)
            # Checkpoint.restore must successfully load that checkpoint.
            ckpt = trackable_utils.Checkpoint(model=model)
            status = ckpt.restore(save_path)
            status.assert_existing_objects_matched()
            # It should, however, refuse to load a checkpoint where an unrelated
            # `save_counter` variable is missing.
            model.layers[1].var = variables_lib.Variable(0.,
                                                         name="save_counter")
            status = ckpt.restore(save_path)
            with self.assertRaises(AssertionError):
                status.assert_existing_objects_matched()
Example #22
class ListTests(keras_parameterized.TestCase):
    @combinations.generate(combinations.combine(mode=["graph", "eager"]))
    def testTracking(self):
        with self.test_session():
            model = HasList()
            output = model(array_ops.ones([32, 2]))
            self.assertAllEqual([32, 12], output.shape)
            self.assertEqual(11, len(model.layers))
            self.assertEqual(10, len(model.layer_list.layers))
            self.assertEqual(
                len(model.layers),
                len(model.layer_list.layers + model.layers_with_updates))
            for index in range(10):
                self.assertEqual(3 + index,
                                 model.layer_list.layers[index].units)
            self.assertEqual(2, len(model._checkpoint_dependencies))
            self.assertIs(model.layer_list,
                          model._checkpoint_dependencies[0].ref)
            self.assertIs(model.layers_with_updates,
                          model._checkpoint_dependencies[1].ref)
            self.assertEqual(
                10,
                len(model._checkpoint_dependencies[0].ref.
                    _checkpoint_dependencies))
            self.evaluate([v.initializer for v in model.variables])
            self.evaluate(model.variables[0].assign([[1., 2., 3.],
                                                     [4., 5., 6.]]))
            save_path = os.path.join(self.get_temp_dir(), "ckpt")
            model.save_weights(save_path)
            self.evaluate(model.variables[0].assign(array_ops.zeros([2, 3])))
            model.load_weights(save_path)
            self.assertAllEqual([[1., 2., 3.], [4., 5., 6.]],
                                self.evaluate(model.variables[0]))
            v = variables.Variable(1.)
            model.var_list = [v]
        self.assertTrue(any(v is t for t in model.variables))
        self.assertTrue(any(v is t for t in model.trainable_variables))
        self.assertFalse(any(v is t for t in model.non_trainable_variables))
        self.assertTrue(
            any(model.layer_list[0].trainable_weights[0] is t
                for t in model.trainable_weights))

    def testSubModelTracking(self):
        model = training.Model()
        model.v = variables.Variable(1.)
        self.assertIn(model.v, model.trainable_weights)
        model2 = training.Model()
        model2.m = [model]
        self.assertIn(model.v, model2.trainable_weights)

    def testSubSequentialTracking(self):
        class _Subclassed(training.Model):
            def __init__(self, wrapped):
                super(_Subclassed, self).__init__()
                self._wrapped = wrapped

            def call(self, x):
                return self._wrapped(x)

        model = sequential.Sequential()
        layer = core.Dense(1)
        model.add(layer)
        model2 = _Subclassed(model)
        model2(array_ops.ones([1, 2]))
        model2.m = [model]
        self.assertIn(layer.kernel, model2.trainable_weights)

    def testLayerTrackedThroughSequential(self):
        class AttrDict(dict):
            def __init__(self, *args, **kwargs):
                super(AttrDict, self).__init__(*args, **kwargs)
                self.__dict__ = self

        def ffnet(layer_sizes, name):
            ff = sequential.Sequential(name=name)
            for i, width in enumerate(layer_sizes):
                ff.add(
                    core.Dense(width,
                               activation=("relu" if i < len(layer_sizes) - 1
                                           else None)))
            return ff

        class MyModel2(training.Model):
            def __init__(self, config, name="my_model_2"):
                super(MyModel2, self).__init__(name=name)
                self._num_tokens = config.num_tokens

                # list of sub-models
                self._ffnet = [
                    ffnet(config.module_layers + (self._num_tokens, ), "ff")
                ]

            def null_input(self):
                return array_ops.zeros([1, self._num_tokens],
                                       dtype=dtypes.float32)

            def call(self, input_, module_index=None):
                return self._ffnet[0](input_)

        m2 = MyModel2(AttrDict(num_tokens=5, module_layers=(50, 30)))

        # Construct
        m2(m2.null_input())
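        # Three Dense layers (widths 50, 30, 5), each with a kernel and a bias.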
        self.assertLen(m2.trainable_variables, 6)

    @combinations.generate(combinations.combine(mode=["graph", "eager"]))
    def testUpdatesForwarded(self):
        model = HasList()
        model_input = array_ops.ones([32, 2])
        model(model_input)
        if context.executing_eagerly():
            self.assertEqual(0, len(model.updates))
        else:
            self.assertGreater(len(model.layers_with_updates[0].updates), 0)
            self.assertEqual(set(model.layers_with_updates[0].updates),
                             set(model.updates))

    @combinations.generate(combinations.combine(mode=["graph", "eager"]))
    def testLossesForwarded(self):
        model = HasList()
        model_input = array_ops.ones([32, 2])
        model(model_input)
        self.assertEqual(2, len(model.losses))

    def testModelContainersCompareEqual(self):
        class HasEqualContainers(training.Model):
            def __init__(self):
                super(HasEqualContainers, self).__init__()
                self.l1 = []
                self.l2 = []

        model = HasEqualContainers()
        first_layer = HasEqualContainers()
        model.l1.append(first_layer)
        second_layer = HasEqualContainers()
        model.l2.append(second_layer)
        self.assertEqual([first_layer, second_layer], model.layers)

    @combinations.generate(combinations.combine(mode=["graph", "eager"]))
    def testTensorConversion(self):
        class ListToTensor(training.Model):
            def __init__(self):
                super(ListToTensor, self).__init__()
                self.l = [1., 2., 3.]

        self.assertAllEqual([1., 2., 3.],
                            self.evaluate(
                                constant_op.constant(ListToTensor().l)))

        self.assertAllEqual([1., 2., 3.],
                            self.evaluate(
                                gen_array_ops.Pack(values=ListToTensor().l)))
Example #23
class BatchNormalizationTest(keras_parameterized.TestCase):

  @keras_parameterized.run_all_keras_modes
  def test_basic_batchnorm(self):
    testing_utils.layer_test(
        keras.layers.BatchNormalization,
        kwargs={
            'momentum': 0.9,
            'epsilon': 0.1,
            'gamma_regularizer': keras.regularizers.l2(0.01),
            'beta_regularizer': keras.regularizers.l2(0.01)
        },
        input_shape=(3, 4, 2))
    testing_utils.layer_test(
        keras.layers.BatchNormalization,
        kwargs={
            'gamma_initializer': 'ones',
            'beta_initializer': 'ones',
            'moving_mean_initializer': 'zeros',
            'moving_variance_initializer': 'ones'
        },
        input_shape=(3, 4, 2))
    testing_utils.layer_test(
        keras.layers.BatchNormalization,
        kwargs={'scale': False,
                'center': False},
        input_shape=(3, 3))
    testing_utils.layer_test(
        keras.layers.BatchNormalization,
        kwargs={
            'gamma_initializer': 'ones',
            'beta_initializer': 'ones',
            'moving_mean_initializer': 'zeros',
            'moving_variance_initializer': 'ones'
        },
        input_shape=(3, 2, 4, 2))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_batchnorm_weights(self):
    layer = keras.layers.BatchNormalization(scale=False, center=False)
    layer.build((None, 3, 4))
    self.assertEqual(len(layer.trainable_weights), 0)
    self.assertEqual(len(layer.weights), 2)

    layer = keras.layers.BatchNormalization()
    layer.build((None, 3, 4))
    self.assertEqual(len(layer.trainable_weights), 2)
    self.assertEqual(len(layer.weights), 4)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_batchnorm_regularization(self):
    layer = keras.layers.BatchNormalization(
        gamma_regularizer='l1', beta_regularizer='l1')
    layer.build((None, 3, 4))
    self.assertEqual(len(layer.losses), 2)
    max_norm = keras.constraints.max_norm
    layer = keras.layers.BatchNormalization(
        gamma_constraint=max_norm, beta_constraint=max_norm)
    layer.build((None, 3, 4))
    self.assertEqual(layer.gamma.constraint, max_norm)
    self.assertEqual(layer.beta.constraint, max_norm)

  @keras_parameterized.run_all_keras_modes
  def test_batchnorm_convnet(self):
    if test.is_gpu_available(cuda_only=True):
      with self.session():
        model = keras.models.Sequential()
        norm = keras.layers.BatchNormalization(
            axis=1, input_shape=(3, 4, 4), momentum=0.8)
        model.add(norm)
        model.compile(
            loss='mse',
            optimizer=gradient_descent.GradientDescentOptimizer(0.01),
            run_eagerly=testing_utils.should_run_eagerly())

        # centered on 5.0, standard deviation 10.0
        x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 3, 4, 4))
        model.fit(x, x, epochs=4, verbose=0)
        out = model.predict(x)
        out -= np.reshape(keras.backend.eval(norm.beta), (1, 3, 1, 1))
        out /= np.reshape(keras.backend.eval(norm.gamma), (1, 3, 1, 1))

        np.testing.assert_allclose(np.mean(out, axis=(0, 2, 3)), 0.0, atol=1e-1)
        np.testing.assert_allclose(np.std(out, axis=(0, 2, 3)), 1.0, atol=1e-1)

  @keras_parameterized.run_all_keras_modes
  def test_batchnorm_convnet_channel_last(self):
    model = keras.models.Sequential()
    norm = keras.layers.BatchNormalization(
        axis=-1, input_shape=(4, 4, 3), momentum=0.8)
    model.add(norm)
    model.compile(
        loss='mse',
        optimizer=gradient_descent.GradientDescentOptimizer(0.01),
        run_eagerly=testing_utils.should_run_eagerly())

    # centered on 5.0, standard deviation 10.0
    x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 4, 4, 3))
    model.fit(x, x, epochs=4, verbose=0)
    out = model.predict(x)
    out -= np.reshape(keras.backend.eval(norm.beta), (1, 1, 1, 3))
    out /= np.reshape(keras.backend.eval(norm.gamma), (1, 1, 1, 3))

    np.testing.assert_allclose(np.mean(out, axis=(0, 1, 2)), 0.0, atol=1e-1)
    np.testing.assert_allclose(np.std(out, axis=(0, 1, 2)), 1.0, atol=1e-1)

  @keras_parameterized.run_all_keras_modes
  def test_batchnorm_correctness(self):
    _run_batchnorm_correctness_test(
        normalization.BatchNormalization, dtype='float32')
    _run_batchnorm_correctness_test(
        normalization_v2.BatchNormalization, dtype='float32')

  @keras_parameterized.run_all_keras_modes
  def test_batchnorm_float16(self):
    _run_batchnorm_correctness_test(
        normalization.BatchNormalization, dtype='float16')
    _run_batchnorm_correctness_test(
        normalization_v2.BatchNormalization, dtype='float16')

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  @testing_utils.enable_v2_dtype_behavior
  def test_batchnorm_mixed_precision(self):
    norm = keras.layers.BatchNormalization(
        axis=-1,
        input_shape=(4, 4, 3),
        momentum=0.8,
        dtype=policy.Policy('mixed_float16'))
    x = np.random.normal(size=(10, 4, 4, 3))
    y = norm(x)
    self.assertEqual(y.dtype, 'float16')
    self.assertEqual(norm.beta.dtype.base_dtype, 'float32')
    self.assertEqual(norm.gamma.dtype.base_dtype, 'float32')
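
The float16/float32 split asserted above comes straight from the policy object (a short sketch of the public Policy attributes):

p = policy.Policy('mixed_float16')
print(p.compute_dtype)   # 'float16': computations run in half precision
print(p.variable_dtype)  # 'float32': variables like beta/gamma stay full precision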

  @combinations.generate(combinations.combine(mode=['graph', 'eager'],
                                              fused=[True, False]))
  @testing_utils.enable_v2_dtype_behavior
  def test_batchnorm_mixed_precision_does_not_overflow(self, fused):
    norm = keras.layers.BatchNormalization(
        axis=-1,
        input_shape=(1, 1, 1),
        fused=fused,
        dtype=policy.Policy('mixed_float16'))
    x = np.array([-1000., 1000.]).reshape((2, 1, 1, 1))
    y = norm(x, training=True)
    expected_y = np.array([-1.0, 1.0]).reshape((2, 1, 1, 1))
    self.assertAllClose(keras.backend.eval(y), expected_y)
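
For contrast, doing this arithmetic purely in float16 would overflow, which is why the mixed-precision layer keeps its statistics in float32 (a small NumPy sketch):

x16 = np.array([-1000., 1000.], dtype=np.float16)
print(np.square(x16))  # [inf, inf]: 1e6 exceeds float16's maximum of 65504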

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_batchnorm_non_trainable_with_fit(self):
    # We use the same data shape for all the data we use in this test.
    # This will prevent any used tf.functions from retracing.
    # This helps us verify that changing trainable and recompiling really
    # does update the training loop, rather than a different data shape
    # triggering a retrace.
    data_shape = (100, 3)

    inputs = keras.Input((3,))
    bn = normalization_v2.BatchNormalization()
    outputs = bn(inputs)
    model = keras.Model(inputs, outputs)
    model.compile(
        'rmsprop',
        'mse',
        run_eagerly=testing_utils.should_run_eagerly())
    model.fit(np.random.random(data_shape), np.random.random(data_shape))

    test_data = np.random.random(data_shape)
    test_targets = np.random.random(data_shape)
    test_loss = model.evaluate(test_data, test_targets)

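    # With trainable=False, BatchNormalization uses its moving statistics even
    # during train_on_batch(), so the training loss matches the test loss.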
    bn.trainable = False
    model.compile(
        'rmsprop',
        'mse',
        run_eagerly=testing_utils.should_run_eagerly())
    train_loss = model.train_on_batch(test_data, test_targets)
    self.assertAlmostEqual(test_loss, train_loss)

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_eager_batchnorm_in_custom_model_call_with_tf_function(self):

    class MyModel(keras.Model):

      def __init__(self):
        super(MyModel, self).__init__()
        self.bn = keras.layers.BatchNormalization()

      @def_function.function()
      def call(self, x, training):
        return self.bn(x, training=training)

    model = MyModel()

    for _ in range(10):
      x = constant_op.constant(0.5, shape=[1, 1])
      model(x, training=True)

    # Make sure the moving mean and variance have been updated
    self.assertAllClose(model.bn.moving_mean.numpy(), [0.047], atol=3e-3)
    self.assertAllClose(model.bn.moving_variance.numpy(), [0.9], atol=3e-2)
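
The expected values follow from the exponential-moving-average update that BatchNormalization applies each step, assuming the default momentum of 0.99 (a plain-Python sanity check):

momentum, moving_mean, moving_var = 0.99, 0.0, 1.0  # default initializers
for _ in range(10):
    # Every batch is the constant 0.5: batch mean 0.5, batch variance 0.
    moving_mean = momentum * moving_mean + (1 - momentum) * 0.5
    moving_var = momentum * moving_var + (1 - momentum) * 0.0
print(moving_mean, moving_var)  # ~0.048 and ~0.904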

  @combinations.generate(combinations.combine(mode=['eager']))
  def test_bessels_correction(self):
    # Bessel's correction is currently only used in the fused case. In the
    # future, it may be used in the nonfused case as well.

    x = constant_op.constant([0., 2.], shape=[2, 1, 1, 1])
    layer = normalization_v2.BatchNormalization(
        momentum=0.5, moving_variance_initializer='zeros')
    layer(x, training=True)
    self.assertTrue(layer.fused)
    # Since fused is used, Bessel's correction is used. The variance of [0, 2]
    # is 2 with Bessel's correction. With momentum 0.5 and an initial moving
    # variance of 0, the updated moving variance is 2 * (1 - 0.5) == 1.
    self.assertAllEqual(self.evaluate(layer.moving_variance), [1.])

    x = constant_op.constant([0., 2.], shape=[2, 1, 1, 1, 1])
    layer = normalization_v2.BatchNormalization(
        momentum=0.5, moving_variance_initializer='zeros')
    layer(x, training=True)
    self.assertFalse(layer.fused)
    # Since fused is not used, Bessel's correction is not used. The variance of
    # [0, 2] is 1 without Bessel's correction. With momentum 0.5 and an initial
    # moving variance of 0, the updated moving variance is 1 * (1 - 0.5) == 0.5.
    self.assertAllEqual(self.evaluate(layer.moving_variance), [0.5])
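
Both expected values can be checked directly in NumPy (a small sketch; ddof=1 applies Bessel's correction):

x = np.array([0., 2.])
print(x.var())        # 1.0: divides by N (no Bessel's correction)
print(x.var(ddof=1))  # 2.0: divides by N - 1 (Bessel's correction)
# One moving-variance update from 0 with momentum 0.5:
print(0.5 * 0.0 + 0.5 * 2.0)  # 1.0, the fused result
print(0.5 * 0.0 + 0.5 * 1.0)  # 0.5, the nonfused result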
Example #24
class TestWeightSavingAndLoadingTFFormat(test.TestCase, parameterized.TestCase):

  def test_keras_optimizer_warning(self):
    graph = ops.Graph()
    with graph.as_default(), self.session(graph):
      model = keras.models.Sequential()
      model.add(keras.layers.Dense(2, input_shape=(3,)))
      model.add(keras.layers.Dense(3))
      model.compile(loss='mse', optimizer=optimizer_v1.Adam(), metrics=['acc'])
      if not ops.executing_eagerly_outside_functions():
        model._make_train_function()
      temp_dir = self.get_temp_dir()
      prefix = os.path.join(temp_dir, 'ckpt')
      with test.mock.patch.object(logging, 'warning') as mock_log:
        model.save_weights(prefix)
        self.assertRegex(str(mock_log.call_args), 'Keras optimizer')

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_tensorflow_format_overwrite(self):
    with self.cached_session() as session:
      model = SubclassedModel()
      temp_dir = self.get_temp_dir()
      prefix = os.path.join(temp_dir, 'ckpt')

      x = constant_op.constant(np.random.random((3, 2)), dtype=dtypes.float32)
      executing_eagerly = context.executing_eagerly()
      model(x)  # pylint: disable=not-callable
      if not executing_eagerly:
        session.run([v.initializer for v in model.variables])
      model.save_weights(prefix, save_format='tensorflow')
      model.save_weights(prefix, save_format='tensorflow', overwrite=True)
      with self.assertRaises(EOFError):
        # Indirectly tests that the user is prompted: with overwrite=False the
        # prompt reads from stdin, which is empty under test, so EOFError is
        # raised.
        model.save_weights(prefix, save_format='tensorflow', overwrite=False)

  def test_no_default_session(self):
    with ops.Graph().as_default():
      self.assertFalse(ops.get_default_session())
      data = np.random.random((1000, 32)).astype(np.float32)
      labels = np.random.random((1000, 10)).astype(np.float32)

      model = keras.models.Sequential([
          keras.layers.Dense(10, activation='softmax'),
          keras.layers.Dense(10, activation='softmax')])

      model.compile(optimizer=training_module.RMSPropOptimizer(0.001),
                    loss='categorical_crossentropy',
                    metrics=['accuracy'])

      model.fit(data, labels)
      fname = os.path.join(self.get_temp_dir(), 'weights', 'ckpt')
      model.save_weights(fname)
      model.load_weights(fname)

  def test_no_graph_pollution(self):
    with ops.get_default_graph().as_default():
      graph = ops.Graph()
      with graph.as_default(), self.session(graph) as session:
        model = SubclassedModel()
        temp_dir = self.get_temp_dir()
        prefix = os.path.join(temp_dir, 'ckpt')

        x = constant_op.constant(np.random.random((3, 2)), dtype=dtypes.float32)
        model(x)  # pylint: disable=not-callable
        session.run([v.initializer for v in model.variables])
        model.save_weights(prefix, save_format='tensorflow')
        op_count = len(graph.get_operations())
        model.save_weights(prefix, save_format='tensorflow')
        self.assertLen(graph.get_operations(), op_count)

        model.load_weights(prefix)
        op_count = len(graph.get_operations())
        model.load_weights(prefix)
        self.assertLen(graph.get_operations(), op_count)

  def _weight_loading_test_template(self, make_model_fn):
    with self.cached_session():
      model = make_model_fn()
      model.compile(
          loss='mse',
          optimizer=training_module.RMSPropOptimizer(0.1),
          metrics=['acc', keras.metrics.CategoricalAccuracy()])
      temp_dir = self.get_temp_dir()
      prefix = os.path.join(temp_dir, 'ckpt')
      train_x = np.random.random((3, 2))
      train_y = np.random.random((3,))
      x = constant_op.constant(train_x, dtype=dtypes.float32)

      model.train_on_batch(train_x, train_y)
      model.save_weights(prefix, save_format='tf')
      ref_y_before_train = model.predict(train_x)
      model.train_on_batch(train_x, train_y)
      ref_y_after_train = model.predict(train_x)
      for v in model.variables:
        self.evaluate(
            v.assign(random_ops.random_normal(shape=array_ops.shape(v))))

      self.addCleanup(shutil.rmtree, temp_dir)

      model.load_weights(prefix)
      self.assertAllClose(ref_y_before_train, self.evaluate(model(x)))

      # Test restore-on-create if this is a subclassed Model (graph Networks
      # will have already created their variables).
      load_model = make_model_fn()
      load_model.load_weights(prefix)
      self.assertAllClose(
          ref_y_before_train,
          self.evaluate(load_model(x)))
      load_model = make_model_fn()
      load_model.load_weights(prefix)
      # We need to run some of the restore ops for predict(), but not all
      # variables have been created yet (optimizer slot variables). Tests
      # incremental restore.
      load_model.predict(train_x)
      load_model.compile(
          loss='mse',
          optimizer=training_module.RMSPropOptimizer(0.1),
          metrics=['acc', keras.metrics.CategoricalAccuracy()])
      load_model.train_on_batch(train_x, train_y)
      self.assertAllClose(ref_y_after_train, self.evaluate(load_model(x)))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_weight_loading_graph_model(self):
    def _make_graph_model():
      a = keras.layers.Input(shape=(2,))
      x = keras.layers.Dense(3)(a)
      b = keras.layers.Dense(1)(x)
      return keras.models.Model(a, b)

    self._weight_loading_test_template(_make_graph_model)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_weight_loading_subclassed_model(self):
    self._weight_loading_test_template(SubclassedModel)

  def _new_layer_weight_loading_test_template(
      self, first_model_fn, second_model_fn):
    with self.cached_session() as session:
      model = first_model_fn()
      temp_dir = self.get_temp_dir()
      prefix = os.path.join(temp_dir, 'ckpt')

      x = constant_op.constant(np.random.random((3, 2)), dtype=dtypes.float32)
      executing_eagerly = context.executing_eagerly()
      ref_y_tensor = model(x)
      if not executing_eagerly:
        session.run([v.initializer for v in model.variables])
      ref_y = self.evaluate(ref_y_tensor)
      model.save_weights(prefix)
      self.assertEqual(
          prefix,
          checkpoint_management.latest_checkpoint(temp_dir))
      for v in model.variables:
        self.evaluate(
            v.assign(random_ops.random_normal(shape=array_ops.shape(v))))

      self.addCleanup(shutil.rmtree, temp_dir)

      second_model = second_model_fn()
      status = second_model.load_weights(prefix)
      second_model(x)
      status.run_restore_ops()
      second_model.save_weights(prefix)
      # Check that the second model's checkpoint loads into the original model
      status = model.load_weights(prefix)
      status.run_restore_ops(session)
      y = self.evaluate(model(x))
      self.assertAllClose(ref_y, y)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_weight_loading_graph_model_added_layer(self):
    def _save_graph_model():
      a = keras.layers.Input(shape=(2,))
      x = keras.layers.Dense(3, name='first')(a)
      b = keras.layers.Dense(1, name='second')(x)
      return keras.models.Model(a, b)
    def _restore_graph_model():
      a = keras.layers.Input(shape=(2,))
      x = keras.layers.Dense(3, name='first')(a)
      y = keras.layers.Dense(1, name='second')(x)
      b = keras.layers.Dense(3, name='secondjr')(y)
      return keras.models.Model(a, b)

    self._new_layer_weight_loading_test_template(
        _save_graph_model, _restore_graph_model)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_weight_loading_graph_model_added_no_weight_layer(self):
    def _save_graph_model():
      a = keras.layers.Input(shape=(2,))
      x = keras.layers.Dense(3, name='first')(a)
      b = keras.layers.Dense(1, name='second')(x)
      return keras.models.Model(a, b)
    def _restore_graph_model():
      a = keras.layers.Input(shape=(2,))
      x = keras.layers.Dense(3, name='first')(a)
      b = keras.layers.Dense(1, name='second')(x)
      y = keras.layers.Dropout(rate=0.1)(b)
      return keras.models.Model(a, y)

    self._new_layer_weight_loading_test_template(
        _save_graph_model, _restore_graph_model)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_weight_loading_subclassed_model_added_layer(self):

    class SubclassedModelRestore(training.Model):

      def __init__(self):
        super(SubclassedModelRestore, self).__init__()
        self.x_layer = keras.layers.Dense(3)
        self.y_layer = keras.layers.Dense(3)
        self.b_layer = keras.layers.Dense(1)

      def call(self, a):
        return self.b_layer(self.y_layer(self.x_layer(a)))

    self._new_layer_weight_loading_test_template(
        SubclassedModel, SubclassedModelRestore)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_incompatible_checkpoint(self):
    save_path = trackable.Checkpoint().save(
        os.path.join(self.get_temp_dir(), 'ckpt'))
    m = DummySubclassModel()
    with self.assertRaisesRegex(AssertionError, 'Nothing to load'):
      m.load_weights(save_path)
    m.dense = keras.layers.Dense(2)
    m.dense(constant_op.constant([[1.]]))
    with self.assertRaisesRegex(AssertionError,
                                'Nothing except the root object matched'):
      m.load_weights(save_path)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_directory_passed(self):
    with self.cached_session():
      m = DummySubclassModel()
      v = m.add_weight(name='v', shape=[])
      self.evaluate(v.assign(42.))
      prefix = os.path.join(self.get_temp_dir(), str(uuid.uuid4()), 'ckpt/')
      m.save_weights(prefix)
      self.evaluate(v.assign(2.))
      m.load_weights(prefix)
      self.assertEqual(42., self.evaluate(v))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_relative_path(self):
    with self.cached_session():
      m = DummySubclassModel()
      v = m.add_weight(name='v', shape=[])
      os.chdir(self.get_temp_dir())

      prefix = 'ackpt'
      self.evaluate(v.assign(42.))
      m.save_weights(prefix)
      self.assertTrue(file_io.file_exists_v2('ackpt.index'))
      self.evaluate(v.assign(1.))
      m.load_weights(prefix)
      self.assertEqual(42., self.evaluate(v))

      prefix = 'subdir/ackpt'
      self.evaluate(v.assign(43.))
      m.save_weights(prefix)
      self.assertTrue(file_io.file_exists_v2('subdir/ackpt.index'))
      self.evaluate(v.assign(2.))
      m.load_weights(prefix)
      self.assertEqual(43., self.evaluate(v))

      prefix = 'ackpt/'
      self.evaluate(v.assign(44.))
      m.save_weights(prefix)
      self.assertTrue(file_io.file_exists_v2('ackpt/.index'))
      self.evaluate(v.assign(3.))
      m.load_weights(prefix)
      self.assertEqual(44., self.evaluate(v))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_nonexistent_prefix_directory(self):
    with self.cached_session():
      m = DummySubclassModel()
      v = m.add_weight(name='v', shape=[])
      self.evaluate(v.assign(42.))
      prefix = os.path.join(self.get_temp_dir(), str(uuid.uuid4()), 'bckpt')
      m.save_weights(prefix)
      self.evaluate(v.assign(2.))
      m.load_weights(prefix)
      self.assertEqual(42., self.evaluate(v))
Example #25
class LayerNormalizationNumericsTest(keras_parameterized.TestCase):
  """Tests LayerNormalization has correct and numerically stable outputs."""

  def _expected_layer_norm(self, x, beta, gamma, batch_input_shape, axis,
                           epsilon):
    """Returns the layer norm, which is computed using NumPy."""
    broadcast_shape = [batch_input_shape[i] if i in axis else 1
                       for i in range(len(batch_input_shape))]
    mean = np.mean(x, axis=axis, keepdims=True)
    var = np.var(x, axis=axis, keepdims=True)
    expected = (x - mean) / np.sqrt(var + epsilon)
    expected *= np.reshape(gamma, broadcast_shape)
    expected += np.reshape(beta, broadcast_shape)
    return expected

  def _test_forward_pass(self, batch_input_shape, axis, fp64_tol=1e-14,
                         fp32_tol=1e-6, fp16_tol=1e-2):
    """Tests the forward pass of layer normalization.

    Args:
      batch_input_shape: The input shape that will be used to test, including
        the batch dimension.
      axis: A list of axes to normalize. Will be passed to the `axis` argument
        of LayerNormalization.
      fp64_tol: The relative and absolute tolerance for float64.
      fp32_tol: The relative and absolute tolerance for float32.
      fp16_tol: The relative and absolute tolerance for float16.
    """
    param_shape = [batch_input_shape[i] for i in axis]
    param_elems = 1
    for dim in param_shape:
      param_elems *= dim
    beta = np.arange(param_elems, dtype='float64').reshape(param_shape)
    gamma = np.arange(1, param_elems + 1, dtype='float64').reshape(param_shape)
    x = np.random.normal(size=batch_input_shape)

    for epsilon in 1e-12, 1e-3:
      expected = self._expected_layer_norm(x, beta, gamma, batch_input_shape,
                                           axis, epsilon)
      for dtype in 'float64', 'float32', 'float16':
        norm = normalization.LayerNormalization(
            axis=axis, dtype=dtype, batch_input_shape=batch_input_shape,
            epsilon=epsilon, beta_initializer=keras.initializers.constant(beta),
            gamma_initializer=keras.initializers.constant(gamma))
        y = norm(keras.backend.cast(x, dtype))
        actual = keras.backend.eval(y)

        if dtype == 'float64':
          tol = fp64_tol
        elif dtype == 'float32':
          tol = fp32_tol
        else:
          assert dtype == 'float16'
          tol = fp16_tol

        # We use absolute tolerances in addition to relative tolerances, because
        # some of the values are very close to zero.
        self.assertAllClose(expected, actual, rtol=tol, atol=tol)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_forward(self):
    # For numeric stability, we ensure the axis's dimension(s) have at least 4
    # elements.
    self._test_forward_pass((4, 3), (0,))
    self._test_forward_pass((3, 4), (1,))
    self._test_forward_pass((4, 3, 2), (0,))
    self._test_forward_pass((2, 4, 2), (1,))
    self._test_forward_pass((2, 3, 4), (2,), fp16_tol=5e-2)
    self._test_forward_pass((2, 3, 2), (0, 2))
    self._test_forward_pass((2, 2, 2, 2), (1, 3))
    self._test_forward_pass((2, 2, 2, 2), (2, 3))
    self._test_forward_pass((2, 3, 4, 5), (3,))

  def _test_backward_pass(self, batch_input_shape, axis, fp64_tol=1e-5,
                          fp32_tol=1e-5, fp16_tol=2e-2):
    """Tests the backwards pass of layer normalization.

    Args:
      batch_input_shape: The input shape that will be used to test, including
        the batch dimension.
      axis: A list of axes to normalize. Will be passed to the `axis` argument
        of LayerNormalization.
      fp64_tol: The relative and absolute tolerance for float64.
      fp32_tol: The relative and absolute tolerance for float32.
      fp16_tol: The relative and absolute tolerance for float16.
    """
    param_shape = [batch_input_shape[i] for i in axis]
    param_elems = 1
    for dim in param_shape:
      param_elems *= dim
    beta = np.arange(param_elems, dtype='float64').reshape(param_shape)
    gamma = np.arange(1, param_elems + 1, dtype='float64').reshape(param_shape)
    x = np.random.normal(size=batch_input_shape)

    for epsilon in 1e-12, 1e-3:
      # Float64 must come first in this list, as we use the float64 numerical
      # gradients to compare to the float32 and float16 symbolic gradients as
      # well. Computing float32/float16 numerical gradients is too numerically
      # unstable.
      for dtype in 'float64', 'float32', 'float16':
        norm = normalization.LayerNormalization(
            axis=axis, dtype=dtype, batch_input_shape=batch_input_shape,
            epsilon=epsilon, beta_initializer=keras.initializers.constant(beta),
            gamma_initializer=keras.initializers.constant(gamma))
        norm.build(x.shape)

        # pylint: disable=cell-var-from-loop
        def forward_fn(x, beta, gamma):
          # We must monkey-patch the attributes of `norm` with the function
          # arguments, so that the gradient checker will properly compute their
          # gradients. The gradient checker computes gradients with respect to
          # the input arguments of `f`.
          with test.mock.patch.object(norm, 'beta', beta):
            with test.mock.patch.object(norm, 'gamma', gamma):
              return norm(x)
        # pylint: enable=cell-var-from-loop
        results = gradient_checker_v2.compute_gradient(
            forward_fn, [keras.backend.cast(x, dtype), norm.beta, norm.gamma])
        ([x_grad_t, beta_grad_t, gamma_grad_t],
         [x_grad_n, beta_grad_n, gamma_grad_n]) = results

        if dtype == 'float64':
          # We use the float64 numeric gradients as the reference, to compare
          # against the symbolic gradients for all dtypes.
          x_grad_ref = x_grad_n
          beta_grad_ref = beta_grad_n
          gamma_grad_ref = gamma_grad_n
          tol = fp64_tol
        elif dtype == 'float32':
          tol = fp32_tol
        else:
          assert dtype == 'float16'
          tol = fp16_tol

        # We use absolute tolerances in addition to relative tolerances, because
        # some of the values are very close to zero.
        self.assertAllClose(x_grad_t, x_grad_ref, rtol=tol, atol=tol)
        self.assertAllClose(beta_grad_t, beta_grad_ref, rtol=tol, atol=tol)
        self.assertAllClose(gamma_grad_t, gamma_grad_ref, rtol=tol, atol=tol)

  # The gradient_checker_v2 does not work properly with LayerNorm in graph mode.
  @testing_utils.run_v2_only
  def test_backward(self):
    # For numeric stability, we ensure the axis's dimension(s) have at least 4
    # elements.
    self._test_backward_pass((4, 3), (0,))
    self._test_backward_pass((2, 4, 2), (1,))
    self._test_backward_pass((2, 3, 4), (2,))
    self._test_backward_pass((2, 3, 2), (0, 2), fp64_tol=5e-4, fp32_tol=5e-4)
    self._test_backward_pass((2, 2, 2, 2), (1, 3))
    self._test_backward_pass((2, 2, 2, 2), (2, 3))
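

# For reference, a minimal NumPy sketch of what the `_expected_layer_norm`
# helper used above is assumed to compute (the real helper is defined
# earlier in this test file and is not shown here; `np` is NumPy):
def _layer_norm_reference_sketch(x, beta, gamma, axis, epsilon):
  # Normalize over the given axes, keeping dims so the statistics
  # broadcast back against `x`.
  mean = np.mean(x, axis=tuple(axis), keepdims=True)
  var = np.var(x, axis=tuple(axis), keepdims=True)
  normed = (x - mean) / np.sqrt(var + epsilon)
  # beta/gamma carry the shape of the normalized axes only; reshape them
  # so they broadcast over the remaining (batch) dimensions.
  bshape = [x.shape[i] if i in axis else 1 for i in range(x.ndim)]
  return normed * np.reshape(gamma, bshape) + np.reshape(beta, bshape)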
Example #26
0
class TestWholeModelSaving(keras_parameterized.TestCase):

  def _save_model_dir(self, dirname='saved_model'):
    temp_dir = self.get_temp_dir()
    self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
    return os.path.join(temp_dir, dirname)

  def _assert_same_weights_and_metrics(self, model, loaded_model):
    """Checks that the loaded weights and metrics are the same as the original.

    Args:
      model: original model
      loaded_model: loaded model
    """
    self.assertAllClose(model.weights, loaded_model.weights)

    if loaded_model.optimizer:
      if testing_utils.get_save_format() == 'tf':
        # TODO(b/153110928): Keras TF format doesn't restore optimizer weights
        # currently.
        return
      self.assertAllClose(model.optimizer.weights,
                          loaded_model.optimizer.weights)

    # In V1/Graph mode, the model isn't built, so the metrics are not loaded
    # immediately (requires model to be called on some data before building
    # metrics).
    check_metrics = tf2.enabled() and context.executing_eagerly()

    if check_metrics:
      self.assertAllEqual([m.name for m in model.metrics],
                          [m.name for m in loaded_model.metrics])

  @keras_parameterized.run_with_all_model_types
  @keras_parameterized.run_all_keras_modes
  def test_save_and_load(self):
    saved_model_dir = self._save_model_dir()
    save_format = testing_utils.get_save_format()

    if save_format == 'h5' and testing_utils.get_model_type() == 'subclass':
      return  # HDF5 format currently does not allow saving subclassed models.

    with self.cached_session():
      model = testing_utils.get_model_from_layers(
          [keras.layers.Dense(2),
           keras.layers.RepeatVector(3),
           keras.layers.TimeDistributed(keras.layers.Dense(3))],
          input_shape=(3,))
      model.compile(
          loss=keras.losses.MSE,
          optimizer=keras.optimizer_v2.rmsprop.RMSprop(lr=0.0001),
          metrics=[
              keras.metrics.categorical_accuracy,
              keras.metrics.CategoricalCrossentropy(
                  name='cce', label_smoothing=constant_op.constant(0.2)),
          ],
          weighted_metrics=[
              keras.metrics.categorical_crossentropy,
              keras.metrics.CategoricalCrossentropy(
                  name='cce', label_smoothing=constant_op.constant(0.2)),
          ],
          sample_weight_mode='temporal')

      x = np.random.random((1, 3))
      y = np.random.random((1, 3, 3))
      model.train_on_batch(x, y)

      out = model.predict(x)
      keras.models.save_model(model, saved_model_dir, save_format=save_format)

      loaded_model = keras.models.load_model(saved_model_dir)
      self._assert_same_weights_and_metrics(model, loaded_model)

      out2 = loaded_model.predict(x)
      self.assertAllClose(out, out2, atol=1e-05)

      eval_out = model.evaluate(x, y)
      eval_out2 = loaded_model.evaluate(x, y)
      self.assertArrayNear(eval_out, eval_out2, 0.001)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_sequential_model_saving_without_input_shape(self):
    saved_model_dir = self._save_model_dir()
    save_format = testing_utils.get_save_format()
    with self.cached_session():
      model = keras.models.Sequential()
      model.add(keras.layers.Dense(2))
      model.add(keras.layers.RepeatVector(3))
      model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
      model.compile(
          loss=keras.losses.MSE,
          optimizer='rmsprop',
          metrics=[
              keras.metrics.categorical_accuracy,
              keras.metrics.CategoricalAccuracy(name='cat_acc')
          ],
          weighted_metrics=[
              keras.metrics.categorical_accuracy,
              keras.metrics.CategoricalAccuracy(name='cat_acc2')
          ],
          sample_weight_mode='temporal')
      x = np.random.random((1, 3))
      y = np.random.random((1, 3, 3))
      model.train_on_batch(x, y)

      out = model.predict(x)
      model.save(saved_model_dir, save_format=save_format)

      new_model = keras.models.load_model(saved_model_dir)

      self._assert_same_weights_and_metrics(model, new_model)

      out2 = new_model.predict(x)
      self.assertAllClose(out, out2, atol=1e-05)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_sequential_model_saving_without_compile(self):
    saved_model_dir = self._save_model_dir()
    save_format = testing_utils.get_save_format()
    with self.cached_session():
      model = keras.models.Sequential()
      model.add(keras.layers.Dense(2, input_shape=(3,)))
      model.add(keras.layers.RepeatVector(3))
      model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))

      x = np.random.random((1, 3))
      out = model.predict(x)

      # Save the model without any compilation or training.
      keras.models.save_model(model, saved_model_dir, save_format=save_format)

      new_model = keras.models.load_model(saved_model_dir)
      self._assert_same_weights_and_metrics(model, new_model)

      out2 = new_model.predict(x)
      self.assertAllClose(out, out2, atol=1e-05)

  def test_sequential_model_saving_2(self):
    saved_model_dir = self._save_model_dir()
    save_format = testing_utils.get_save_format()

    with ops.Graph().as_default(), self.cached_session():
      # test with custom optimizer, loss

      class CustomOp(optimizer_v1.RMSprop):
        pass

      def custom_loss(y_true, y_pred):
        return keras.losses.mse(y_true, y_pred)

      model = keras.models.Sequential()
      model.add(keras.layers.Dense(2, input_shape=(3,)))
      model.add(keras.layers.Dense(3))
      model.compile(loss=custom_loss, optimizer=CustomOp(), metrics=['acc'])

      x = np.random.random((1, 3))
      y = np.random.random((1, 3))
      model.train_on_batch(x, y)

      out = model.predict(x)
      keras.models.save_model(model, saved_model_dir, save_format=save_format)

      new_model = keras.models.load_model(
          saved_model_dir,
          custom_objects={'CustomOp': CustomOp,
                          'custom_loss': custom_loss})
      self._assert_same_weights_and_metrics(model, new_model)

      out2 = new_model.predict(x)
      self.assertAllClose(out, out2, atol=1e-05)

  def test_saving_without_compilation(self):
    saved_model_dir = self._save_model_dir()
    save_format = testing_utils.get_save_format()
    model = keras.models.Sequential()
    model.add(keras.layers.Dense(2, input_shape=(3,)))
    model.add(keras.layers.Dense(3))
    model.compile(loss='mse', optimizer='sgd', metrics=['acc'])

    keras.models.save_model(model, saved_model_dir, save_format=save_format)
    model = keras.models.load_model(saved_model_dir)

  def test_saving_with_tf_optimizer(self):
    saved_model_dir = self._save_model_dir()
    save_format = testing_utils.get_save_format()

    model = keras.models.Sequential()
    model.add(keras.layers.Dense(2, input_shape=(3,)))
    model.add(keras.layers.Dense(3))
    model.compile(loss='mse',
                  optimizer=training_module.AdadeltaOptimizer(0.1),
                  metrics=['acc'])

    keras.models.save_model(model, saved_model_dir, save_format=save_format)
    model = keras.models.load_model(saved_model_dir)

  def test_saving_right_after_compilation(self):
    saved_model_dir = self._save_model_dir()
    save_format = testing_utils.get_save_format()
    with self.cached_session():
      model = keras.models.Sequential()
      model.add(keras.layers.Dense(2, input_shape=(3,)))
      model.add(keras.layers.Dense(3))
      model.compile(loss='mse', optimizer='sgd', metrics=['acc'])
      if not ops.executing_eagerly_outside_functions():
        model._make_train_function()
      keras.models.save_model(model, saved_model_dir, save_format=save_format)
      model = keras.models.load_model(saved_model_dir)

  def test_saving_lambda_numpy_array_arguments(self):
    saved_model_dir = self._save_model_dir()
    save_format = testing_utils.get_save_format()

    if h5py is None:
      self.skipTest('h5py required to run this test')

    mean = np.random.random((4, 2, 3))
    std = np.abs(np.random.random((4, 2, 3))) + 1e-5
    inputs = keras.layers.Input(shape=(4, 2, 3))
    output = keras.layers.Lambda(lambda image, mu, std: (image - mu) / std,
                                 arguments={'mu': mean, 'std': std})(inputs)
    model = keras.models.Model(inputs, output)
    model.compile(loss='mse', optimizer='sgd', metrics=['acc'])

    keras.models.save_model(model, saved_model_dir, save_format=save_format)

    model = keras.models.load_model(saved_model_dir)

    self.assertAllClose(mean, model.layers[1].arguments['mu'])
    self.assertAllClose(std, model.layers[1].arguments['std'])

  def test_saving_model_with_long_layer_names(self):
    saved_model_dir = self._save_model_dir()
    save_format = testing_utils.get_save_format()
    with self.cached_session():
      # This layer name will make the `layers_name` HDF5 attribute blow
      # out of proportion. Note that it fits into the internal HDF5
      # attribute memory limit on its own, but because h5py converts
      # the list of layer names into a numpy array, which uses the same
      # amount of memory for every item, it increases the memory
      # requirements substantially.
      x = keras.Input(shape=(2,), name='input_' + ('x' * (2**15)))
      f = x
      for i in range(4):
        f = keras.layers.Dense(2, name='dense_%d' % (i,))(f)
      model = keras.Model(inputs=[x], outputs=[f])
      model.compile(
          'adam', loss=keras.losses.MeanSquaredError(), metrics=['acc'])

      x = np.random.random((1, 2))
      y = np.random.random((1, 2))
      model.train_on_batch(x, y)
      out = model.predict(x)

      keras.models.save_model(model, saved_model_dir, save_format=save_format)
      model = keras.models.load_model(saved_model_dir)

      if save_format in ['tf', 'tensorflow']:
        return
      # Check that the HDF5 file contains a chunked array
      # of layer names.
      with h5py.File(saved_model_dir, 'r') as h5file:
        num_names_arrays = len([attr for attr in h5file['model_weights'].attrs
                                if attr.startswith('layer_names')])
      # The layer names array should have been chunked.
      self.assertGreater(num_names_arrays, 0)
      out2 = model.predict(x)
      self.assertAllClose(out, out2, atol=1e-05)

  def test_saving_model_with_long_weights_names(self):
    saved_model_dir = self._save_model_dir()
    save_format = testing_utils.get_save_format()

    with self.cached_session():
      x = keras.Input(shape=(2,), name='nested_model_input')
      f = x
      for i in range(4):
        f = keras.layers.Dense(2, name='nested_model_dense_%d' % (i,))(f)
      # This layer name will make the `weights_name`
      # HDF5 attribute blow out of proportion.
      f = keras.layers.Dense(2, name='nested_model_output' + ('x' * (2**14)))(f)
      nested_model = keras.Model(inputs=[x], outputs=[f], name='nested_model')

      x = keras.Input(shape=(2,), name='outer_model_input')
      f = nested_model(x)
      f = keras.layers.Dense(2, name='outer_model_output')(f)

      model = keras.Model(inputs=[x], outputs=[f])
      model.compile(loss='mse', optimizer='adam', metrics=['acc'])

      x = np.random.random((1, 2))
      y = np.random.random((1, 2))
      model.train_on_batch(x, y)
      out = model.predict(x)

      keras.models.save_model(model, saved_model_dir, save_format=save_format)
      model = keras.models.load_model(saved_model_dir)

      if save_format in ['h5', 'hdf5', 'keras']:
        # Check that the HDF5 file contains a chunked array
        # of weight names.
        with h5py.File(saved_model_dir, 'r') as h5file:
          num_weight_arrays = len(
              [attr for attr in h5file['model_weights']['nested_model'].attrs
               if attr.startswith('weight_names')])
        # The weight names array should have been chunked.
        self.assertGreater(num_weight_arrays, 0)
      out2 = model.predict(x)
      self.assertAllClose(out, out2, atol=1e-05)

  def test_model_saving_to_pre_created_h5py_file(self):
    saved_model_dir = self._save_model_dir()
    save_format = testing_utils.get_save_format()
    with ops.Graph().as_default(), self.cached_session():
      inputs = keras.Input(shape=(3,))
      x = keras.layers.Dense(2)(inputs)
      outputs = keras.layers.Dense(3)(x)

      model = keras.Model(inputs, outputs)
      model.compile(
          loss=keras.losses.MSE,
          optimizer=optimizer_v1.Adam(),
          metrics=[
              keras.metrics.categorical_accuracy,
              keras.metrics.CategoricalAccuracy()
          ])
      x = np.random.random((1, 3))
      y = np.random.random((1, 3))
      model.train_on_batch(x, y)

      out = model.predict(x)

      keras.models.save_model(model, saved_model_dir, save_format=save_format)
      loaded_model = keras.models.load_model(saved_model_dir)
      out1 = loaded_model.predict(x)
      self.assertAllClose(out, out1, atol=1e-05)
      if save_format in ['tf', 'tensorflow']:
        return

      # Test h5 format specifically
      fd, fname = tempfile.mkstemp('.h5')
      with h5py.File(fname, mode='r+') as h5file:
        keras.models.save_model(model, h5file)
        loaded_model = keras.models.load_model(h5file)
        out2 = loaded_model.predict(x)
      self.assertAllClose(out, out2, atol=1e-05)

      # Test non-default options in h5
      with h5py.File('_', driver='core',
                     backing_store=False) as h5file:
        keras.models.save_model(model, h5file)
        loaded_model = keras.models.load_model(h5file)
        out2 = loaded_model.predict(x)
      self.assertAllClose(out, out2, atol=1e-05)

      # Cleanup
      os.close(fd)
      os.remove(fname)

  def test_model_saving_to_new_dir_path(self):
    saved_model_dir = os.path.join(self._save_model_dir(), 'newdir',
                                   'saved_model')
    save_format = testing_utils.get_save_format()

    with self.cached_session():
      model = keras.models.Sequential()
      model.add(keras.layers.Dense(2, input_shape=(3,)))
      model.add(keras.layers.RepeatVector(3))
      model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))

      x = np.random.random((1, 3))
      out = model.predict(x)

      keras.models.save_model(model, saved_model_dir, save_format=save_format)

      new_model = keras.models.load_model(saved_model_dir)
      self._assert_same_weights_and_metrics(model, new_model)

      out2 = new_model.predict(x)
      self.assertAllClose(out, out2, atol=1e-05)

  def test_model_raise_exception_with_failed_saving(self):
    if h5py is None:
      self.skipTest('h5py required to run this test')

    saved_model_dir = self._save_model_dir()
    saved_model_path = os.path.join(saved_model_dir, 'saved_model.h5')

    with self.cached_session():
      model = keras.models.Sequential()
      model.add(keras.layers.Dense(2, input_shape=(3,)))
      model.add(keras.layers.RepeatVector(3))
      model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))

      with self.assertRaisesRegex(OSError, 'Unable to create file'):
        with h5py.File(saved_model_path, 'w'):
          keras.models.save_model(model, saved_model_path)

  def test_saving_constant_initializer_with_numpy(self):
    saved_model_dir = self._save_model_dir()
    save_format = testing_utils.get_save_format()

    model = keras.models.Sequential()
    model.add(
        keras.layers.Dense(
            2,
            input_shape=(3,),
            kernel_initializer=keras.initializers.Constant(np.ones((3, 2)))))
    model.add(keras.layers.Dense(3))
    model.compile(loss='mse', optimizer='sgd', metrics=['acc'])
    keras.models.save_model(model, saved_model_dir, save_format=save_format)
    model = keras.models.load_model(saved_model_dir)

  def test_saving_group_naming_h5py(self):
    # Test saving a model with a layer whose name is a prefix of a previous
    # layer's name.

    temp_dir = self.get_temp_dir()
    self.addCleanup(shutil.rmtree, temp_dir)
    h5_path = os.path.join(temp_dir, 'test.h5')

    input_layer = keras.layers.Input((None, None, 3), name='test_input')
    x = keras.layers.Conv2D(1, 1, name='conv1/conv')(input_layer)
    x = keras.layers.Activation('relu', name='conv1')(x)
    model = keras.models.Model(inputs=input_layer, outputs=x)

    model.save_weights(h5_path)
    model.load_weights(h5_path)

  def test_primitive_attrs_contain_no_extraneous_strings(self):
    if h5py is None:
      self.skipTest('h5py required to run this test')

    saved_model_dir = self._save_model_dir()
    save_format = testing_utils.get_save_format()
    model = keras.models.Sequential()
    model.add(keras.layers.Dense(1, input_shape=[2]))
    model.save(saved_model_dir, save_format=save_format)
    if save_format in ['tf', 'tensorflow']:
      return

    h5file = h5py.File(saved_model_dir, 'r')
    self.assertRegex(h5file.attrs['keras_version'], r'^[\d]+\.[\d]+\.[\S]+$')

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_functional_model_with_custom_loss_and_metric(self):
    def _make_model():
      inputs = keras.Input(shape=(4,))
      x = keras.layers.Dense(8, activation='relu')(inputs)
      outputs = keras.layers.Dense(3, activation='softmax')(x)
      model = keras.Model(inputs=inputs, outputs=outputs)
      custom_loss = keras.layers.Lambda(lambda x: keras.backend.sum(x * x))(x)
      model.add_loss(custom_loss)
      model.add_metric(custom_loss, aggregation='mean', name='custom_loss')
      return model

    saved_model_dir = self._save_model_dir()
    save_format = testing_utils.get_save_format()

    model = _make_model()
    model.compile(
        loss=keras.losses.SparseCategoricalCrossentropy(),
        optimizer=optimizers.gradient_descent_v2.SGD(),
        metrics=[keras.metrics.SparseCategoricalCrossentropy()])
    x = np.random.normal(size=(32, 4))
    y = np.random.randint(0, 3, size=32)
    model.train_on_batch(x, y)
    evaluation_results = model.evaluate(x, y)
    # Save and reload model.
    model.save(saved_model_dir, save_format=save_format)
    del model  # Prevent misuse.
    loaded_model = keras.models.load_model(saved_model_dir)
    loaded_model_eval_results = loaded_model.evaluate(x, y)
    # Assert all evaluation results are the same.
    self.assertAllClose(evaluation_results, loaded_model_eval_results, 1e-9)
    # Check correctness of the loss calculation.
    self.assertAllGreater(evaluation_results, 0.)
    evaluation_results = dict(
        zip(loaded_model.metrics_names, evaluation_results))
    self.assertNear(
        evaluation_results['sparse_categorical_crossentropy'] +
        evaluation_results['custom_loss'], evaluation_results['loss'], 1e-6)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_save_uncompiled_model_with_optimizer(self):
    with self.cached_session() as session:
      saved_model_dir = self._save_model_dir()
      save_format = testing_utils.get_save_format()
      model = keras.models.Sequential([keras.layers.Dense(1, input_shape=(3,))])
      # Set the model's optimizer but don't compile. This can happen if the
      # model is trained with a custom training loop.
      model.optimizer = keras.optimizer_v2.rmsprop.RMSprop(lr=0.0001)
      if not context.executing_eagerly():
        session.run([v.initializer for v in model.variables])
      model.save(saved_model_dir, save_format=save_format)

      if save_format in ['tf', 'tensorflow']:
        loaded = keras.models.load_model(saved_model_dir)
        self.assertIsInstance(loaded.optimizer,
                              keras.optimizer_v2.optimizer_v2.OptimizerV2)

  @combinations.generate(combinations.combine(mode=['eager']))
  def test_functional_model_with_getitem_op_layer(self):
    inp = keras.Input(shape=(8,))

    out = inp[:]
    model = keras.Model(
        inputs=[inp],
        outputs=out)
    batch_size = 7
    x = array_ops.stack([
        math_ops.range(8) for _ in range(batch_size)])
    args = [x]
    expected = x[:]

    self.assertAllEqual(model(args), expected)
    self.assertAllEqual(model.predict(args, batch_size=batch_size), expected)

    # Make sure it can be successfully saved and loaded
    save_format = testing_utils.get_save_format()
    saved_model_dir = self._save_model_dir()
    keras.models.save_model(model, saved_model_dir, save_format=save_format)

    loaded_model = keras.models.load_model(saved_model_dir)

    self.assertAllEqual(loaded_model(args), expected)
    self.assertAllEqual(loaded_model.predict(args, batch_size=batch_size),
                        expected)
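

# A minimal, self-contained sketch of the save/load round trip that the
# tests in this class exercise, written against the public tf.keras API
# (an illustrative assumption, not part of the test file itself):
def _save_load_roundtrip_sketch(tmpdir='/tmp/saved_model_demo'):
  import numpy as np
  import tensorflow as tf
  model = tf.keras.Sequential([tf.keras.layers.Dense(3, input_shape=(3,))])
  model.compile(loss='mse', optimizer='sgd')
  x = np.random.random((1, 3)).astype('float32')
  out = model.predict(x)
  model.save(tmpdir)  # Defaults to the SavedModel format in TF2.
  loaded = tf.keras.models.load_model(tmpdir)
  # The loaded model should reproduce the original predictions.
  np.testing.assert_allclose(out, loaded.predict(x), atol=1e-5)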
Example #27
0
        opt = rmsprop.RMSprop(learning_rate=1., momentum=0.2, centered=False)
        opt.minimize(lambda: v1 + v2, var_list=[v1, v2])
        # There should be the iterations variable, plus two unique slot
        # variables for each of v1 and v2: 5 in total.
        self.assertLen(set({id(v) for v in opt.variables()}), 5)
        self.assertEqual(self.evaluate(opt.variables()[0]),
                         self.evaluate(opt.iterations))

        opt = rmsprop.RMSprop(learning_rate=1., momentum=0.2, centered=True)
        opt.minimize(lambda: v1 + v2, var_list=[v1, v2])
        # There should be the iterations variable, plus three unique slot
        # variables for each of v1 and v2: 7 in total.
        self.assertLen(set({id(v) for v in opt.variables()}), 7)
        self.assertEqual(self.evaluate(opt.variables()[0]),
                         self.evaluate(opt.iterations))
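        # Slot-count arithmetic behind the two assertions above (for
        # illustration): RMSprop with momentum keeps "rms" and "momentum"
        # slots per variable, and centered RMSprop additionally keeps a
        # moving-gradient-average ("mg") slot, so:
        #   1 iterations + 2 vars * 2 slots = 5   (centered=False)
        #   1 iterations + 2 vars * 3 slots = 7   (centered=True)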


@combinations.generate(combinations.combine(mode=["graph", "eager"]))
class SlotColocationTest(test.TestCase, parameterized.TestCase):
    @parameterized.parameters([True, False])
    @test_util.run_gpu_only
    def testRunMinimizeOnGPUForCPUVariables(self, use_resource):
        with ops.device("/device:CPU:0"):
            if use_resource:
                var0 = resource_variable_ops.ResourceVariable(
                    [1.0, 2.0], dtype=dtypes.float32)
                var1 = resource_variable_ops.ResourceVariable(
                    [3.0, 4.0], dtype=dtypes.float32)
            else:
                var0 = variables.Variable([1.0, 2.0], dtype=dtypes.float32)
                var1 = variables.Variable([3.0, 4.0], dtype=dtypes.float32)

        def loss():
            return 5 * var0 + 3 * var1
Example #28
0
class AdagradOptimizerTest(test.TestCase, parameterized.TestCase):
    def doTestBasic(self, use_callable_params=False):
        for dtype in _DATA_TYPES:
            var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
            var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
            grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
            grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
            var0 = variables.Variable(var0_np)
            var1 = variables.Variable(var1_np)
            grads0 = constant_op.constant(grads0_np)
            grads1 = constant_op.constant(grads1_np)

            learning_rate = lambda: 3.0
            if not use_callable_params:
                learning_rate = learning_rate()

            ada_opt = adagrad.Adagrad(learning_rate)

            accum0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
            accum1_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)

            if not context.executing_eagerly():
                ada_update = ada_opt.apply_gradients(
                    zip([grads0, grads1], [var0, var1]))
                self.evaluate(variables.global_variables_initializer())

            # Fetch params to validate initial values
            v0_val, v1_val = self.evaluate([var0, var1])
            self.assertAllClose([1.0, 2.0], v0_val)
            self.assertAllClose([3.0, 4.0], v1_val)

            # Run 3 steps of adagrad
            for _ in range(3):
                if not context.executing_eagerly():
                    self.evaluate(ada_update)
                else:
                    ada_opt.apply_gradients(zip([grads0, grads1],
                                                [var0, var1]))
                var0_np, accum0_np = adagrad_update_numpy(
                    var0_np, accum0_np, grads0_np, 3.0)
                var1_np, accum1_np = adagrad_update_numpy(
                    var1_np, accum1_np, grads1_np, 3.0)
                self.assertAllCloseAccordingToType(var0_np,
                                                   self.evaluate(var0))
                self.assertAllCloseAccordingToType(var1_np,
                                                   self.evaluate(var1))

    @combinations.generate(combinations.combine(mode=["graph", "eager"]))
    def testBasic(self):
        self.doTestBasic()

    @combinations.generate(combinations.combine(mode=["eager"]))
    def testBasicCallableParams(self):
        self.doTestBasic(use_callable_params=True)

    def testBasicWithLearningRateDecay(self):
        for dtype in _DATA_TYPES:
            var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
            var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
            grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
            grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
            var0 = variables.Variable(var0_np)
            var1 = variables.Variable(var1_np)
            grads0 = constant_op.constant(grads0_np)
            grads1 = constant_op.constant(grads1_np)

            learning_rate = 3.0
            decay = 0.5

            ada_opt = adagrad.Adagrad(learning_rate, decay=decay)

            accum0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
            accum1_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)

            if not context.executing_eagerly():
                ada_update = ada_opt.apply_gradients(
                    zip([grads0, grads1], [var0, var1]))
                self.evaluate(variables.global_variables_initializer())

            # Fetch params to validate initial values
            v0_val, v1_val = self.evaluate([var0, var1])
            self.assertAllClose([1.0, 2.0], v0_val)
            self.assertAllClose([3.0, 4.0], v1_val)

            # Run 3 steps of adagrad
            for t in range(3):
                if not context.executing_eagerly():
                    self.evaluate(ada_update)
                else:
                    ada_opt.apply_gradients(zip([grads0, grads1],
                                                [var0, var1]))
                lr_np = learning_rate / (1 + decay * t)
                var0_np, accum0_np = adagrad_update_numpy(
                    var0_np, accum0_np, grads0_np, lr_np)
                var1_np, accum1_np = adagrad_update_numpy(
                    var1_np, accum1_np, grads1_np, lr_np)
                self.assertAllCloseAccordingToType(var0_np,
                                                   self.evaluate(var0))
                self.assertAllCloseAccordingToType(var1_np,
                                                   self.evaluate(var1))

    def testBasicWithLargeEpsilon(self):
        var0_np = np.array([1.0, 2.0])
        var1_np = np.array([3.0, 4.0])
        grads0_np = np.array([0.1, 0.1])
        grads1_np = np.array([0.01, 0.01])
        var0 = variables.Variable(var0_np)
        var1 = variables.Variable(var1_np)
        grads0 = constant_op.constant(grads0_np)
        grads1 = constant_op.constant(grads1_np)

        learning_rate = 3.0

        ada_opt = adagrad.Adagrad(learning_rate, epsilon=1.0)

        accum0_np = np.array([0.1, 0.1])
        accum1_np = np.array([0.1, 0.1])

        if not context.executing_eagerly():
            ada_update = ada_opt.apply_gradients(
                zip([grads0, grads1], [var0, var1]))
            self.evaluate(variables.global_variables_initializer())

        # Fetch params to validate initial values
        v0_val, v1_val = self.evaluate([var0, var1])
        self.assertAllClose([1.0, 2.0], v0_val)
        self.assertAllClose([3.0, 4.0], v1_val)

        # Run 3 steps of adagrad
        for _ in range(3):
            if not context.executing_eagerly():
                self.evaluate(ada_update)
            else:
                ada_opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
            var0_np, accum0_np = adagrad_update_numpy(var0_np, accum0_np,
                                                      grads0_np, 3.0, 1.0)
            var1_np, accum1_np = adagrad_update_numpy(var1_np, accum1_np,
                                                      grads1_np, 3.0, 1.0)
            self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
            self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))

    def testBasicWithLearningRateInverseTimeDecay(self):
        for dtype in _DATA_TYPES:
            var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
            var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
            grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
            grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
            var0 = variables.Variable(var0_np)
            var1 = variables.Variable(var1_np)
            grads0 = constant_op.constant(grads0_np)
            grads1 = constant_op.constant(grads1_np)

            learning_rate = 3.0
            decay = 0.5
            lr_schedule = learning_rate_schedule.InverseTimeDecay(
                learning_rate, decay_steps=1.0, decay_rate=decay)

            ada_opt = adagrad.Adagrad(lr_schedule)

            accum0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
            accum1_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)

            if not context.executing_eagerly():
                ada_update = ada_opt.apply_gradients(
                    zip([grads0, grads1], [var0, var1]))
                self.evaluate(variables.global_variables_initializer())

            # Fetch params to validate initial values
            v0_val, v1_val = self.evaluate([var0, var1])
            self.assertAllClose([1.0, 2.0], v0_val)
            self.assertAllClose([3.0, 4.0], v1_val)

            # Run 3 steps of adagrad
            for t in range(3):
                if not context.executing_eagerly():
                    self.evaluate(ada_update)
                else:
                    ada_opt.apply_gradients(zip([grads0, grads1],
                                                [var0, var1]))
                lr_np = learning_rate / (1 + decay * t)
                var0_np, accum0_np = adagrad_update_numpy(
                    var0_np, accum0_np, grads0_np, lr_np)
                var1_np, accum1_np = adagrad_update_numpy(
                    var1_np, accum1_np, grads1_np, lr_np)
                self.assertAllCloseAccordingToType(var0_np,
                                                   self.evaluate(var0))
                self.assertAllCloseAccordingToType(var1_np,
                                                   self.evaluate(var1))

    def testMinimizeSparseResourceVariable(self):
        # TODO(tanzheny, omalleyt): Fix test in eager mode.
        with ops.Graph().as_default():
            for dtype in _DATA_TYPES:
                var0 = variables.Variable([[1.0, 2.0], [3.0, 4.0]],
                                          dtype=dtype)
                x = constant_op.constant([[4.0], [5.0]], dtype=dtype)

                def loss():
                    pred = math_ops.matmul(
                        embedding_ops.embedding_lookup([var0], [0]), x)  # pylint: disable=cell-var-from-loop
                    return pred * pred

                sgd_op = adagrad.Adagrad(1.0).minimize(loss, var_list=[var0])
                self.evaluate(variables.global_variables_initializer())
                # Fetch params to validate initial values
                self.assertAllCloseAccordingToType([[1.0, 2.0], [3.0, 4.0]],
                                                   self.evaluate(var0))
                # Run 1 step of adagrad
                self.evaluate(sgd_op)
                # Validate updated params
                self.assertAllCloseAccordingToType([[0, 1], [3, 4]],
                                                   self.evaluate(var0),
                                                   atol=0.01)

    def testTensorLearningRate(self):
        # TODO(tanzheny, omalleyt): Fix test in eager mode.
        with ops.Graph().as_default():
            for dtype in _DATA_TYPES:
                var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
                var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
                grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
                grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
                var0 = variables.Variable(var0_np)
                var1 = variables.Variable(var1_np)
                grads0 = constant_op.constant(grads0_np)
                grads1 = constant_op.constant(grads1_np)

                learning_rate = constant_op.constant(3.0)
                ada_opt = adagrad.Adagrad(learning_rate)
                ada_update = ada_opt.apply_gradients(
                    zip([grads0, grads1], [var0, var1]))
                self.evaluate(variables.global_variables_initializer())
                # Fetch params to validate initial values
                self.assertAllClose([1.0, 2.0], self.evaluate(var0))
                self.assertAllClose([3.0, 4.0], self.evaluate(var1))
                accum0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
                accum1_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
                # Run 3 steps of adagrad
                for _ in range(3):
                    self.evaluate(ada_update)
                    var0_np, accum0_np = adagrad_update_numpy(
                        var0_np, accum0_np, grads0_np, learning_rate)
                    var1_np, accum1_np = adagrad_update_numpy(
                        var1_np, accum1_np, grads1_np, learning_rate)
                    self.assertAllCloseAccordingToType(var0_np,
                                                       self.evaluate(var0))
                    self.assertAllCloseAccordingToType(var1_np,
                                                       self.evaluate(var1))

    def testSparseBasic(self):
        # TODO(tanzheny, omalleyt): Fix test in eager mode.
        with ops.Graph().as_default():
            for dtype in _DATA_TYPES:
                var0_np = np.array([1.0, 1.0, 2.0], dtype=dtype.as_numpy_dtype)
                grads0_np = np.array([0.1, 0, 0.1], dtype=dtype.as_numpy_dtype)
                var1_np = np.array([3.0, 3.0, 4.0], dtype=dtype.as_numpy_dtype)
                grads1_np = np.array([0.01, 0, 0.01],
                                     dtype=dtype.as_numpy_dtype)

                var0 = variables.Variable(var0_np)
                var1 = variables.Variable(var1_np)
                grads0_np_indices = np.array([0, 2], dtype=np.int32)
                grads0 = ops.IndexedSlices(
                    constant_op.constant(grads0_np[grads0_np_indices]),
                    constant_op.constant(grads0_np_indices),
                    constant_op.constant([3]))
                grads1_np_indices = np.array([0, 2], dtype=np.int32)
                grads1 = ops.IndexedSlices(
                    constant_op.constant(grads1_np[grads1_np_indices]),
                    constant_op.constant(grads1_np_indices),
                    constant_op.constant([3]))
                learning_rate = 3.0
                ada_opt = adagrad.Adagrad(learning_rate)
                ada_update = ada_opt.apply_gradients(
                    zip([grads0, grads1], [var0, var1]))
                self.evaluate(variables.global_variables_initializer())

                # Fetch params to validate initial values
                self.assertAllClose([1.0, 1.0, 2.0], self.evaluate(var0))
                self.assertAllClose([3.0, 3.0, 4.0], self.evaluate(var1))

                accum0_np = np.array([0.1, 0.1, 0.1],
                                     dtype=dtype.as_numpy_dtype)
                accum1_np = np.array([0.1, 0.1, 0.1],
                                     dtype=dtype.as_numpy_dtype)

                # Run 3 steps of adagrad
                for _ in range(3):
                    self.evaluate(ada_update)

                    var0_np, accum0_np = sparse_adagrad_update_numpy(
                        var0_np, accum0_np, grads0_np_indices,
                        grads0_np[grads0_np_indices], learning_rate)
                    var1_np, accum1_np = sparse_adagrad_update_numpy(
                        var1_np, accum1_np, grads1_np_indices,
                        grads1_np[grads1_np_indices], learning_rate)
                    self.assertAllCloseAccordingToType(var0_np,
                                                       self.evaluate(var0))
                    self.assertAllCloseAccordingToType(var1_np,
                                                       self.evaluate(var1))

    def testSparseSingleVarDim(self):
        # TODO(tanzheny, omalleyt): Fix test in eager mode.
        with ops.Graph().as_default():
            for dtype in _DATA_TYPES:
                var0_np = np.array([1.0], dtype=dtype.as_numpy_dtype)
                grads0_np = np.array([0.1], dtype=dtype.as_numpy_dtype)

                var0 = variables.Variable(var0_np)
                grads0_np_indices = np.array([0], dtype=np.int32)
                grads0 = ops.IndexedSlices(
                    constant_op.constant(grads0_np[grads0_np_indices]),
                    constant_op.constant(grads0_np_indices),
                    constant_op.constant([3]))
                learning_rate = 3.0
                ada_opt = adagrad.Adagrad(learning_rate, epsilon=1.)
                ada_update = ada_opt.apply_gradients(zip([grads0], [var0]))
                self.evaluate(variables.global_variables_initializer())

                # Fetch params to validate initial values
                self.assertAllClose([1.0], self.evaluate(var0))

                accum0_np = np.array([0.1], dtype=dtype.as_numpy_dtype)

                # Run 3 steps of adagrad
                for _ in range(3):
                    self.evaluate(ada_update)

                    var0_np, accum0_np = sparse_adagrad_update_numpy(
                        var0_np,
                        accum0_np,
                        grads0_np_indices,
                        grads0_np[grads0_np_indices],
                        learning_rate,
                        epsilon=1.)
                    self.assertAllCloseAccordingToType(var0_np,
                                                       self.evaluate(var0))

    def testSparseRepeatedIndices(self):
        # TODO(tanzheny, omalleyt): Fix test in eager mode.
        with ops.Graph().as_default():
            for dtype in _DATA_TYPES:
                var_np = np.array([[1.0], [2.0]], dtype=dtype.as_numpy_dtype)

                repeated_index_update_var = variables.Variable(var_np,
                                                               dtype=dtype)
                aggregated_update_var = variables.Variable(var_np, dtype=dtype)
                grad_repeated_index = ops.IndexedSlices(
                    constant_op.constant([0.1, 0.1], shape=[2, 1],
                                         dtype=dtype),
                    constant_op.constant([1, 1]), constant_op.constant([2, 1]))
                grad_aggregated = ops.IndexedSlices(
                    constant_op.constant([0.2], shape=[1, 1], dtype=dtype),
                    constant_op.constant([1]), constant_op.constant([2, 1]))
                repeated_update = adagrad.Adagrad(3.0).apply_gradients([
                    (grad_repeated_index, repeated_index_update_var)
                ])
                aggregated_update = adagrad.Adagrad(3.0).apply_gradients([
                    (grad_aggregated, aggregated_update_var)
                ])
                self.evaluate(variables.global_variables_initializer())
                self.assertAllClose(self.evaluate(aggregated_update_var),
                                    self.evaluate(repeated_index_update_var))
                for _ in range(3):
                    self.evaluate(repeated_update)
                    self.evaluate(aggregated_update)
                    self.assertAllClose(
                        self.evaluate(aggregated_update_var),
                        self.evaluate(repeated_index_update_var))
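                # For illustration: a gradient listed twice at the same index
                # (0.1 at row 1, twice) must be accumulated to match the
                # aggregated gradient (0.2 at row 1, once), which is what the
                # equality assertions above verify.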

    def testSparseRepeatedIndicesByEmbeddingLookUp(self):
        # TODO(tanzheny, omalleyt): Fix test in eager mode.
        with ops.Graph().as_default():
            for dtype in _DATA_TYPES:
                var_repeated = variables.Variable([1.0, 2.0], dtype=dtype)
                loss_repeated = lambda: math_ops.reduce_sum(  # pylint: disable=g-long-lambda
                    embedding_ops.embedding_lookup(var_repeated, [0, 0]))  # pylint: disable=cell-var-from-loop
                var_aggregated = variables.Variable([1.0, 2.0], dtype=dtype)
                loss_aggregated = lambda: 2 * math_ops.reduce_sum(  # pylint: disable=g-long-lambda
                    embedding_ops.embedding_lookup(var_aggregated, [0]))  # pylint: disable=cell-var-from-loop
                update_op_repeated = adagrad.Adagrad(2.0).minimize(
                    loss_repeated, var_list=[var_repeated])
                update_op_aggregated = adagrad.Adagrad(2.0).minimize(
                    loss_aggregated, var_list=[var_aggregated])
                self.evaluate(variables.global_variables_initializer())
                self.assertAllCloseAccordingToType(
                    self.evaluate(var_repeated), self.evaluate(var_aggregated))
                for _ in range(3):
                    self.evaluate(update_op_repeated)
                    self.evaluate(update_op_aggregated)
                    self.assertAllCloseAccordingToType(
                        self.evaluate(var_repeated),
                        self.evaluate(var_aggregated))

    def testSparseStability(self):
        # TODO(tanzheny, omalleyt): Fix test in eager mode.
        with ops.Graph().as_default():
            for dtype in [dtypes.half]:
                shape = [1, 6]
                var0_np = np.array([[
                    0.00872496, -0.106952, 0.110467, 0.226505, -0.0147257,
                    -0.0105945
                ]],
                                   dtype=dtype.as_numpy_dtype)
                var0 = variables.Variable(var0_np)
                grads0_np = np.array([[
                    -5.91278e-05, 5.31673e-05, -2.5779e-06, 4.29153e-05,
                    -8.4877e-05, -9.48906e-05
                ]],
                                     dtype=dtype.as_numpy_dtype)
                grads0 = ops.IndexedSlices(constant_op.constant(grads0_np),
                                           constant_op.constant([0]),
                                           constant_op.constant(shape))
                ada_opt = adagrad.Adagrad(1.0)
                ada_update = ada_opt.apply_gradients(zip([grads0], [var0]))
                slot0 = ada_opt.get_slot(var0, "accumulator")
                init = variables.global_variables_initializer()
                for _ in range(100):
                    self.evaluate(init)
                    self.evaluate(ada_update)
                    self.assertAllCloseAccordingToType(
                        np.array([[0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]),
                        self.evaluate(slot0))
                    self.assertAllCloseAccordingToType(
                        np.array([[
                            0.00891194, -0.10712013, 0.11047515, 0.22636929,
                            -0.0144573, -0.01029443
                        ]]), self.evaluate(var0))

    def testSharing(self):
        # TODO(tanzheny, omalleyt): Fix test in eager mode.
        with ops.Graph().as_default():
            for dtype in _DATA_TYPES:
                var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
                grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
                var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
                grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)

                var0 = variables.Variable(var0_np)
                var1 = variables.Variable(var1_np)
                grads0 = constant_op.constant(grads0_np)
                grads1 = constant_op.constant(grads1_np)

                learning_rate = 3.0
                ada_opt = adagrad.Adagrad(learning_rate)
                # Apply the optimizer twice.  Both applications will use
                # the same accums.
                ada_update1 = ada_opt.apply_gradients(
                    zip([grads0, grads1], [var0, var1]))
                ada_update2 = ada_opt.apply_gradients(
                    zip([grads0, grads1], [var0, var1]))
                slot0 = ada_opt.get_slot(var0, "accumulator")
                self.assertEqual(slot0.shape, var0.shape)
                slot1 = ada_opt.get_slot(var1, "accumulator")
                self.assertEqual(slot1.shape, var1.shape)
                self.evaluate(variables.global_variables_initializer())

                # Fetch params to validate initial values.
                self.assertAllClose([1.0, 2.0], self.evaluate(var0))
                self.assertAllClose([3.0, 4.0], self.evaluate(var1))
                # Mix the first and the second adagrad for 3 steps.
                self.evaluate(ada_update1)
                self.evaluate(ada_update2)
                self.evaluate(ada_update1)

                accum0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
                accum1_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
                for _ in range(3):
                    var0_np, accum0_np = adagrad_update_numpy(
                        var0_np, accum0_np, grads0_np, learning_rate)
                    var1_np, accum1_np = adagrad_update_numpy(
                        var1_np, accum1_np, grads1_np, learning_rate)
                self.assertAllCloseAccordingToType(var0_np,
                                                   self.evaluate(var0))
                self.assertAllCloseAccordingToType(var1_np,
                                                   self.evaluate(var1))

    def testConstructAdagradWithLR(self):
        opt = adagrad.Adagrad(lr=1.0)
        opt_2 = adagrad.Adagrad(learning_rate=0.1, lr=1.0)
        opt_3 = adagrad.Adagrad(learning_rate=0.1)
        self.assertIsInstance(opt.lr, variables.Variable)
        self.assertIsInstance(opt_2.lr, variables.Variable)
        self.assertIsInstance(opt_3.lr, variables.Variable)

        self.evaluate(variables.global_variables_initializer())
        self.assertAllClose(self.evaluate(opt.lr), (1.0))
        self.assertAllClose(self.evaluate(opt_2.lr), (1.0))
        self.assertAllClose(self.evaluate(opt_3.lr), (0.1))
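

# The tests above compare against `adagrad_update_numpy` and
# `sparse_adagrad_update_numpy`, NumPy reference helpers defined earlier in
# the test file (not shown). Sketches of the standard Adagrad update they
# are assumed to implement (`np` is NumPy):
def _adagrad_update_numpy_sketch(param, accum, g, lr, epsilon=1e-7):
    # Accumulate squared gradients, then scale the step by 1/sqrt(accum).
    accum = accum + g * g
    param = param - lr * g / (np.sqrt(accum) + epsilon)
    return param, accum


def _sparse_adagrad_update_numpy_sketch(param, accum, indices, grads, lr,
                                        epsilon=1e-7):
    # The same update, applied only to the rows named by `indices`.
    param, accum = param.copy(), accum.copy()
    for i, g in zip(indices, grads):
        accum[i] += g * g
        param[i] -= lr * g / (np.sqrt(accum[i]) + epsilon)
    return param, accum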
Example #29
0
from tensorflow.python.keras.optimizer_v2 import gradient_descent
from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


def _maybe_serialized(lr_decay, serialize_and_deserialize):
    if serialize_and_deserialize:
        serialized = learning_rate_schedule.serialize(lr_decay)
        return learning_rate_schedule.deserialize(serialized)
    else:
        return lr_decay


@combinations.generate(
    combinations.combine(serialize=[False, True], mode=["graph", "eager"]))
class LRDecayTestV2(test_util.TensorFlowTestCase, parameterized.TestCase):
    def testContinuous(self, serialize):
        self.evaluate(variables.global_variables_initializer())
        step = 5
        decayed_lr = learning_rate_schedule.ExponentialDecay(0.05, 10, 0.96)
        decayed_lr = _maybe_serialized(decayed_lr, serialize)
        expected = .05 * 0.96**(5.0 / 10.0)
        self.assertAllClose(self.evaluate(decayed_lr(step)), expected, 1e-6)

    def testStaircase(self, serialize):
        if context.executing_eagerly():
            step = variables.Variable(0)
            self.evaluate(variables.global_variables_initializer())
            decayed_lr = learning_rate_schedule.ExponentialDecay(
                .1, 3, 0.96, staircase=True)
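
# For reference, `ExponentialDecay` is assumed to compute
#   lr(step) = initial_lr * decay_rate ** (step / decay_steps),
# flooring the exponent when staircase=True so the rate only drops every
# `decay_steps` steps; this matches the expected value asserted in
# testContinuous above. A NumPy sketch (`np` is NumPy):
def _exponential_decay_sketch(initial_lr, decay_steps, decay_rate, step,
                              staircase=False):
    p = step / decay_steps
    if staircase:
        p = np.floor(p)
    return initial_lr * decay_rate ** p
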
class GradientDescentOptimizerTest(test.TestCase, parameterized.TestCase):
    @combinations.generate(combinations.combine(mode=["graph", "eager"]))
    def testBasic(self):
        for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
            var0 = resource_variable_ops.ResourceVariable([1.0, 2.0],
                                                          dtype=dtype)
            var1 = resource_variable_ops.ResourceVariable([3.0, 4.0],
                                                          dtype=dtype)
            grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
            grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
            sgd = gradient_descent.SGD(3.0)
            sgd_op = sgd.apply_gradients(zip([grads0, grads1], [var0, var1]))
            self.evaluate(variables.global_variables_initializer())
            # Run 1 step of sgd
            self.evaluate(sgd_op)
            # Validate updated params
            self.assertAllCloseAccordingToType(
                [1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1], self.evaluate(var0))
            self.assertAllCloseAccordingToType(
                [3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01], self.evaluate(var1))

    def _test_basic_sgd_with_learning_rate_decay(self, sgd, dtype):
        var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
        var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
        grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
        grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
        if not context.executing_eagerly():
            sgd_op = sgd.apply_gradients(zip([grads0, grads1], [var0, var1]))
        self.evaluate(variables.global_variables_initializer())
        # Run 2 steps of sgd
        if not context.executing_eagerly():
            self.evaluate(sgd_op)
        else:
            sgd.apply_gradients(zip([grads0, grads1], [var0, var1]))
        # Validate updated params
        self.assertAllCloseAccordingToType([1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1],
                                           self.evaluate(var0))
        self.assertAllCloseAccordingToType(
            [3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01], self.evaluate(var1))

        if not context.executing_eagerly():
            self.evaluate(sgd_op)
        else:
            sgd.apply_gradients(zip([grads0, grads1], [var0, var1]))
        # Validate updated params
        self.assertAllCloseAccordingToType(
            [1.0 - 3.0 * 0.1 - 2.0 * 0.1, 2.0 - 3.0 * 0.1 - 2.0 * 0.1],
            self.evaluate(var0))
        self.assertAllCloseAccordingToType(
            [3.0 - 3.0 * 0.01 - 2.0 * 0.01, 4.0 - 3.0 * 0.01 - 2.0 * 0.01],
            self.evaluate(var1))
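        # Step-size arithmetic for the decay checks above: with decay d the
        # t-th step uses lr_t = lr / (1 + d * t), so step 0 uses 3.0 and
        # step 1 uses 3.0 / (1 + 0.5) = 2.0, matching the 3.0*g and 2.0*g
        # deltas asserted.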

    @combinations.generate(combinations.combine(mode=["graph", "eager"]))
    def testBasicWithLearningRateDecay(self):
        for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
            learning_rate = 3.0
            decay = 0.5
            sgd = gradient_descent.SGD(learning_rate=learning_rate,
                                       decay=decay)
            self._test_basic_sgd_with_learning_rate_decay(sgd, dtype)

    @combinations.generate(combinations.combine(mode=["graph", "eager"]))
    def testBasicWithLearningRateInverseTimeDecay(self):
        for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
            learning_rate = learning_rate_schedule.InverseTimeDecay(
                3.0, decay_steps=1.0, decay_rate=0.5)
            sgd = gradient_descent.SGD(learning_rate=learning_rate)
            self._test_basic_sgd_with_learning_rate_decay(sgd, dtype)

    @combinations.generate(combinations.combine(mode=["graph", "eager"]))
    def testBasicWithLearningRateInverseTimeDecaySerializeAndDeserialize(self):
        for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
            learning_rate = learning_rate_schedule.InverseTimeDecay(
                3.0, decay_steps=1.0, decay_rate=0.5)
            sgd = gradient_descent.SGD(learning_rate=learning_rate)
            sgd = gradient_descent.SGD.from_config(sgd.get_config())
            self._test_basic_sgd_with_learning_rate_decay(sgd, dtype)

    @combinations.generate(combinations.combine(mode=["graph", "eager"]))
    def testBasicCallableParams(self):
        for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
            var0 = resource_variable_ops.ResourceVariable([1.0, 2.0],
                                                          dtype=dtype)
            var1 = resource_variable_ops.ResourceVariable([3.0, 4.0],
                                                          dtype=dtype)
            grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
            grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
            lr = lambda: 3.0
            sgd = gradient_descent.SGD(lr)
            sgd_op = sgd.apply_gradients(zip([grads0, grads1], [var0, var1]))
            self.evaluate(variables.global_variables_initializer())
            # Run 1 step of sgd
            self.evaluate(sgd_op)
            # Validate updated params
            self.assertAllCloseAccordingToType(
                [1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1], self.evaluate(var0))
            self.assertAllCloseAccordingToType(
                [3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01], self.evaluate(var1))

    @combinations.generate(combinations.combine(mode=["graph", "eager"]))
    def testMinimizeResourceVariable(self):
        for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
            var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]],
                                                          dtype=dtype)
            var1 = resource_variable_ops.ResourceVariable([3.0], dtype=dtype)
            x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
            loss = lambda: math_ops.matmul(var0, x) + var1  # pylint: disable=cell-var-from-loop
            sgd = gradient_descent.SGD(1.0)
            sgd_op = sgd.minimize(loss, [var0, var1])
            self.evaluate(variables.global_variables_initializer())
            # Run 1 step of sgd
            self.evaluate(sgd_op)
            # Validate updated params
            self.assertAllCloseAccordingToType([[1.0 - 4.0, 2.0 - 5.0]],
                                               self.evaluate(var0))
            self.assertAllCloseAccordingToType([3.0 - 1.0],
                                               self.evaluate(var1))

    def testMinimizeSparseResourceVariable(self):
        # TODO(tanzheny, omalleyt): Fix test in eager mode.
        with ops.Graph().as_default():
            for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
                var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]],
                                                              dtype=dtype)
                var1 = resource_variable_ops.ResourceVariable([3.0],
                                                              dtype=dtype)
                x = constant_op.constant([[4.0], [5.0]], dtype=dtype)

                def loss():
                    pred = math_ops.matmul(
                        embedding_ops.embedding_lookup([var0], [0]), x)  # pylint: disable=cell-var-from-loop
                    pred += var1  # pylint: disable=cell-var-from-loop
                    return pred * pred

                sgd_op = gradient_descent.SGD(1.0).minimize(loss, [var0, var1])
                self.evaluate(variables.global_variables_initializer())
                # Run 1 step of sgd
                self.evaluate(sgd_op)
                # Validate updated params
                np_pred = 1.0 * 4.0 + 2.0 * 5.0 + 3.0
                np_grad = 2 * np_pred
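                # Chain rule: d(pred**2)/d(var0) = 2 * pred * x^T; routing
                # var0 through embedding_lookup exercises the sparse
                # (IndexedSlices) gradient path.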
                self.assertAllCloseAccordingToType(
                    [[1.0 - np_grad * 4.0, 2.0 - np_grad * 5.0]],
                    self.evaluate(var0))
                self.assertAllCloseAccordingToType([3.0 - np_grad],
                                                   self.evaluate(var1))

    def testTensorLearningRate(self):
        for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
            var0 = variables.Variable([1.0, 2.0], dtype=dtype)
            var1 = variables.Variable([3.0, 4.0], dtype=dtype)
            grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
            grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
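            # The learning rate may also be supplied as a constant tensor.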
            lrate = constant_op.constant(3.0)
            sgd_op = gradient_descent.SGD(lrate).apply_gradients(
                zip([grads0, grads1], [var0, var1]))
            self.evaluate(variables.global_variables_initializer())
            # Run 1 step of sgd
            self.evaluate(sgd_op)
            # Validate updated params
            self.assertAllCloseAccordingToType(
                [1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1], self.evaluate(var0))
            self.assertAllCloseAccordingToType(
                [3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01], self.evaluate(var1))

    def testGradWrtRef(self):
        # TODO(tanzheny, omalleyt): Fix test in eager mode.
        with ops.Graph().as_default():
            for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
                opt = gradient_descent.SGD(3.0)
                values = [1.0, 3.0]
                vars_ = [variables.Variable([v], dtype=dtype) for v in values]
                loss = lambda: vars_[0] + vars_[1]  # pylint: disable=cell-var-from-loop
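                # loss = vars_[0] + vars_[1], so the gradient with respect to
                # each variable is 1.0.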
                grads_and_vars = opt._compute_gradients(loss, vars_)
                self.evaluate(variables.global_variables_initializer())
                for grad, _ in grads_and_vars:
                    self.assertAllCloseAccordingToType([1.0],
                                                       self.evaluate(grad))

    def testSparseBasic(self):
        # TODO(tanzheny, omalleyt): Fix test in eager mode.
        with ops.Graph().as_default():
            for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
                var0 = variables.Variable([[1.0], [2.0]], dtype=dtype)
                var1 = variables.Variable([[3.0], [4.0]], dtype=dtype)
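                # IndexedSlices(values, indices, dense_shape) encodes a
                # sparse gradient; only the rows listed in indices receive
                # updates.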
                grads0 = ops.IndexedSlices(
                    constant_op.constant([0.1], shape=[1, 1], dtype=dtype),
                    constant_op.constant([0]), constant_op.constant([2, 1]))
                grads1 = ops.IndexedSlices(
                    constant_op.constant([0.01], shape=[1, 1], dtype=dtype),
                    constant_op.constant([1]), constant_op.constant([2, 1]))
                sgd_op = gradient_descent.SGD(3.0).apply_gradients(
                    zip([grads0, grads1], [var0, var1]))
                self.evaluate(variables.global_variables_initializer())
                # Run 1 step of sgd
                self.evaluate(sgd_op)
                # Validate updated params
                self.assertAllCloseAccordingToType([[1.0 - 3.0 * 0.1], [2.0]],
                                                   self.evaluate(var0))
                self.assertAllCloseAccordingToType([[3.0], [4.0 - 3.0 * 0.01]],
                                                   self.evaluate(var1))

    def testSparseBasicWithLearningRateDecay(self):
        # TODO(tanzheny, omalleyt): Fix test in eager mode.
        with ops.Graph().as_default():
            for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
                var0 = variables.Variable([[1.0], [2.0]], dtype=dtype)
                var1 = variables.Variable([[3.0], [4.0]], dtype=dtype)
                grads0 = ops.IndexedSlices(
                    constant_op.constant([0.1], shape=[1, 1], dtype=dtype),
                    constant_op.constant([0]), constant_op.constant([2, 1]))
                grads1 = ops.IndexedSlices(
                    constant_op.constant([0.01], shape=[1, 1], dtype=dtype),
                    constant_op.constant([1]), constant_op.constant([2, 1]))
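                # The legacy `decay` argument divides the learning rate by
                # (1 + decay * iterations) on each step.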
                sgd_op = gradient_descent.SGD(3.0, decay=0.5).apply_gradients(
                    zip([grads0, grads1], [var0, var1]))
                self.evaluate(variables.global_variables_initializer())
                # Run the first of two sgd steps
                self.evaluate(sgd_op)
                # Validate updated params
                self.assertAllCloseAccordingToType([[1.0 - 3.0 * 0.1], [2.0]],
                                                   self.evaluate(var0))
                self.assertAllCloseAccordingToType([[3.0], [4.0 - 3.0 * 0.01]],
                                                   self.evaluate(var1))

                # Run the second step; decay=0.5 lowers the effective
                # learning rate to 3.0 / (1 + 0.5 * 1) = 2.0
                self.evaluate(sgd_op)
                # Validate updated params after the second step
                self.assertAllCloseAccordingToType(
                    [[1.0 - 3.0 * 0.1 - 2.0 * 0.1], [2.0]],
                    self.evaluate(var0))
                self.assertAllCloseAccordingToType(
                    [[3.0], [4.0 - 3.0 * 0.01 - 2.0 * 0.01]],
                    self.evaluate(var1))

    def testCapturingInDefunWhileExecutingEagerly(self):
        with context.eager_mode():
            optimizer = gradient_descent.SGD(1.0)

            def step():
                self.v = resource_variable_ops.ResourceVariable(1.0)
                with backprop.GradientTape() as tape:
                    loss = self.v**2
                grad = tape.gradient(loss, self.v)
                optimizer.apply_gradients([(grad, self.v)])
                return self.v.read_value()

            compiled_step = function.defun(step)
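            # v is reset to 1.0 on each call and d(v**2)/dv = 2.0, so a single
            # SGD step with lr=1.0 always leaves v at 1.0 - 2.0 = -1.0.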

            self.assertEqual(float(step()), -1.0)
            self.assertEqual(float(compiled_step()), -1.0)
            # This shouldn't fail; in particular, the learning rate tensor should
            # be an EagerTensor once again, not a graph Tensor.
            self.assertEqual(float(step()), -1.0)

    def testConstructSGDWithLR(self):
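        # `lr` is a deprecated alias for `learning_rate`; when both are
        # passed, the value given as `lr` takes precedence.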
        opt = gradient_descent.SGD(lr=1.0)
        opt_2 = gradient_descent.SGD(learning_rate=0.1, lr=1.0)
        opt_3 = gradient_descent.SGD(learning_rate=0.1)
        self.assertIsInstance(opt.lr, variables.Variable)
        self.assertIsInstance(opt_2.lr, variables.Variable)
        self.assertIsInstance(opt_3.lr, variables.Variable)

        self.evaluate(variables.global_variables_initializer())
        self.assertAllClose(self.evaluate(opt.lr), 1.0)
        self.assertAllClose(self.evaluate(opt_2.lr), 1.0)
        self.assertAllClose(self.evaluate(opt_3.lr), 0.1)