コード例 #1
0
  def testNoneGradPassesThroughCorrectly(self):
    """A `None` gradient should pass through `clip_gradient_norms` untouched.

    The single returned (gradient, variable) pair must still carry `None` as
    its gradient and the very same variable object that was passed in.
    """
    gradient = None
    variable = variables_lib.Variable(self._zero_vec, dtype=tf.float32)

    gradients_to_variables = (gradient, variable)
    # Unpacking into a one-element list also checks that exactly one pair is
    # returned for the one pair passed in.
    [gradients_to_variables
    ] = learning.clip_gradient_norms([gradients_to_variables], self._max_norm)

    self.assertIsNone(gradients_to_variables[0])
    # Identity check: the same variable object is expected back. assertIs does
    # not depend on how tf.Variable overloads `==`.
    self.assertIs(gradients_to_variables[1], variable)
コード例 #2
0
  def testOrdinaryGradIsClippedCorrectly(self):
    """A dense gradient tensor should be clipped while its variable is kept.

    The clipped gradient is evaluated in the test-managed session and compared
    against the expected fixture `self._clipped_grad_vec`.
    """
    gradient = tf.constant(self._grad_vec, dtype=tf.float32)
    variable = variables_lib.Variable(self._zero_vec, dtype=tf.float32)
    gradients_to_variables = (gradient, variable)
    # Unpacking into a one-element list also checks that exactly one pair is
    # returned for the one pair passed in.
    [gradients_to_variables
    ] = learning.clip_gradient_norms([gradients_to_variables], self._max_norm)

    # Ensure the variable passed through as the very same object. assertIs
    # does not depend on how tf.Variable overloads `==`.
    self.assertIs(gradients_to_variables[1], variable)

    with self.cached_session() as sess:
      actual_gradient = sess.run(gradients_to_variables[0])
    np_testing.assert_almost_equal(actual_gradient, self._clipped_grad_vec)
コード例 #3
0
  def testIndexedSlicesGradIsClippedCorrectly(self):
    """An `IndexedSlices` gradient should have only its values clipped.

    The clipped gradient must reuse the original `indices` and `dense_shape`
    tensors and stay paired with the original variable. Its `values` should
    evaluate to the expected fixture `self._clipped_grad_vec`.
    """
    # NOTE(review): a well-formed IndexedSlices needs
    # len(indices) == values.shape[0] and every index < dense_shape[0]. With
    # three indices, one of which is 4, and dense_shape = [self._grad_vec.size],
    # both cannot hold at once. The test only clips `values`, so it still
    # passes, but the fixture looks malformed -- confirm against the fixture.
    sparse_grad_indices = np.array([0, 1, 4])
    sparse_grad_dense_shape = [self._grad_vec.size]

    values = tf.constant(self._grad_vec, dtype=tf.float32)
    indices = tf.constant(sparse_grad_indices, dtype=tf.int32)
    dense_shape = tf.constant(sparse_grad_dense_shape, dtype=tf.int32)

    gradient = ops.IndexedSlices(values, indices, dense_shape)
    variable = variables_lib.Variable(self._zero_vec, dtype=tf.float32)

    gradients_to_variables = (gradient, variable)
    # Unpacking into a one-element list also checks that exactly one pair is
    # returned for the one pair passed in.
    [gradients_to_variables
    ] = learning.clip_gradient_norms([gradients_to_variables], self._max_norm)

    # Ensure the built IndexedSlices has the right form: the same variable,
    # indices and dense_shape objects are expected back. assertIs does not
    # depend on how TF tensors/variables overload `==`.
    self.assertIs(gradients_to_variables[1], variable)
    self.assertIs(gradients_to_variables[0].indices, indices)
    self.assertIs(gradients_to_variables[0].dense_shape, dense_shape)

    # Use the TestCase-managed session (as the dense-gradient test does)
    # rather than constructing a tf.Session directly.
    with self.cached_session() as sess:
      actual_gradient = sess.run(gradients_to_variables[0].values)
    np_testing.assert_almost_equal(actual_gradient, self._clipped_grad_vec)