Ejemplo n.º 1
0
 def testMultipleOfNoneGradRaisesError(self):
     """Requesting a multiple of a `None` gradient must raise ValueError.

     Builds a (gradient, variable) pair whose gradient is `None`, registers a
     multiplier for that variable, and expects `learning.multiply_gradients`
     to reject the request.
     """
     gradient = tf.constant(self._grad_vec, dtype=tf.float32)
     # `gradient` is only used to give the variable a matching shape/dtype.
     variable = variables_lib.Variable(tf.zeros_like(gradient))
     grad_to_var = (None, variable)
     gradient_multipliers = {variable: self._multiplier}
     with self.assertRaises(ValueError):
         # Pass a *list* of pairs, as the sibling tests do. Handing over the
         # bare tuple would not present (None, variable) as a single pair.
         # NOTE(review): with a bare tuple the ValueError may come from
         # argument validation instead of the None-gradient check — confirm
         # against the multiply_gradients implementation.
         learning.multiply_gradients([grad_to_var], gradient_multipliers)
Ejemplo n.º 2
0
    def testMultipleGradientsWithVariables(self):
        """A dense gradient is scaled and its variable is returned alongside.

        Builds one (gradient, variable) pair from ``self._grad_vec``, maps the
        variable to ``self._multiplier``, and compares the evaluated result
        against ``self._multiplied_grad_vec`` to 5 decimal places.
        """
        grad_tensor = tf.constant(self._grad_vec, dtype=tf.float32)
        var = variables_lib.Variable(tf.zeros_like(grad_tensor))
        multipliers = {var: self._multiplier}

        # Single-element unpacking also checks that exactly one pair returns.
        (scaled_pair,) = learning.multiply_gradients(
            [(grad_tensor, var)], multipliers)

        # The second slot of the returned pair must equal the input variable.
        self.assertEqual(scaled_pair[1], var)

        with self.cached_session() as session:
            scaled_value = session.run(scaled_pair[0])
        np_testing.assert_almost_equal(
            scaled_value, self._multiplied_grad_vec, 5)
Ejemplo n.º 3
0
    def testTensorMultiplierOfGradient(self):
        """A tensor-valued multiplier is re-evaluated on every session run.

        The multiplier is a ``tf.where`` over a boolean variable: it yields
        ``self._multiplier`` while the flag is True and ``1.0`` once the flag
        is set to False. The scaled gradient is evaluated in both states.
        """
        grad_tensor = tf.constant(self._grad_vec, dtype=tf.float32)
        var = variables_lib.Variable(tf.zeros_like(grad_tensor))
        use_multiplier = variables_lib.Variable(True)
        multipliers = {
            var: tf.where(use_multiplier, self._multiplier, 1.0),
        }

        (scaled_pair,) = learning.multiply_gradients(
            [(grad_tensor, var)], multipliers)
        scaled_grad = scaled_pair[0]

        with self.cached_session() as session:
            session.run(variables_lib.global_variables_initializer())
            # Flag starts as True: the gradient should come back multiplied.
            value_with_flag = session.run(scaled_grad)
            # Flip the flag and evaluate the same tensor again.
            session.run(use_multiplier.assign(False))
            value_without_flag = session.run(scaled_grad)

        np_testing.assert_almost_equal(
            value_with_flag, self._multiplied_grad_vec, 5)
        np_testing.assert_almost_equal(
            value_without_flag, self._grad_vec, 5)
Ejemplo n.º 4
0
    def testIndexedSlicesGradIsMultiplied(self):
        """An ``IndexedSlices`` gradient has its values scaled.

        Checks that the returned pair keeps the same variable, indices and
        dense shape, and that the evaluated ``values`` match
        ``self._multiplied_grad_vec`` to 5 decimal places.
        """
        slice_values = tf.constant(self._grad_vec, dtype=tf.float32)
        slice_indices = tf.constant([0, 1, 2], dtype=tf.int32)
        slice_shape = tf.constant([self._grad_vec.size], dtype=tf.int32)
        sparse_grad = ops.IndexedSlices(
            slice_values, slice_indices, slice_shape)

        var = variables_lib.Variable(tf.zeros((1, 3)))
        multipliers = {var: self._multiplier}

        (scaled_pair,) = learning.multiply_gradients(
            [(sparse_grad, var)], multipliers)

        # Structural checks on the returned pair before evaluating anything.
        self.assertEqual(scaled_pair[1], var)
        self.assertEqual(scaled_pair[0].indices, slice_indices)
        self.assertEqual(scaled_pair[0].dense_shape, slice_shape)

        with self.cached_session() as session:
            scaled_values = session.run(scaled_pair[0].values)
        np_testing.assert_almost_equal(
            scaled_values, self._multiplied_grad_vec, 5)
Ejemplo n.º 5
0
 def testNonDictMultiplierRaisesError(self):
     """A multiplier argument that is not a dict must raise ValueError.

     Passes the integer ``3`` where the sibling tests pass a
     ``{variable: multiplier}`` mapping.
     """
     grad_tensor = tf.constant(self._grad_vec, dtype=tf.float32)
     var = variables_lib.Variable(tf.zeros_like(grad_tensor))
     pairs = [(grad_tensor, var)]
     with self.assertRaises(ValueError):
         learning.multiply_gradients(pairs, 3)