def testOpScaleDifferentDtypes(self):
    """Tests that scale_gradient can be built for different float dtypes.

    The op is built first for a float16 input and then for a float32 input in
    the same graph. Building the second one must not fail, and each output must
    keep the dtype of its input.
    """
    x_1 = tf.placeholder(tf.float16, shape=())
    y_1 = snt.scale_gradient(x_1, 0.1)
    # scale_gradient throws here if the Defun func_name does not use the dtype.
    x_2 = tf.placeholder(tf.float32, shape=())
    y_2 = snt.scale_gradient(x_2, 0.1)
    # The forward pass of scale_gradient returns its input unchanged, so the
    # output dtype should match the placeholder dtype.
    self.assertEqual(y_1.dtype, tf.float16)
    self.assertEqual(y_2.dtype, tf.float32)
def testTwoOps(self):
    """Tests that the op can be instantiated twice with appropriate results.

    Implementations with inappropriate global registration of gradients will
    fail this test.
    """
    inputs = tf.placeholder(tf.float32, [1])
    squared = inputs * inputs
    scaled_once = snt.scale_gradient(squared, 0.1)
    scaled_twice = snt.scale_gradient(scaled_once, 0.1)
    grad = tf.gradients([scaled_twice], [inputs])[0]

    feed = {inputs: [3.0]}
    with self.test_session() as sess:
        grad_val, out_val = sess.run([grad, scaled_twice], feed_dict=feed)

    # d/dx (x * x) = 2x, multiplied by 0.1 once per scale_gradient op.
    expected_grad = 2 * 0.1**2 * 3.0
    self.assertAlmostEqual(grad_val[0], expected_grad, places=6)
    # The forward value is untouched by gradient scaling.
    self.assertAlmostEqual(out_val[0], 3.0 ** 2, places=6)
# NOTE(review): a duplicate definition of testOpScale used to sit here. It was
# identical to the testOpScale that follows apart from whitespace. In a class
# body the later `def` rebinds the name, so this copy never ran and was removed.
def testOpScale(self, x_, scale):
    """Checks op type, gradient and forward value of snt.scale_gradient.

    Args:
      x_: Value fed into the single-element input placeholder.
      scale: Gradient scale passed to snt.scale_gradient.
    """
    inputs = tf.placeholder(tf.float32, [1])
    squared = inputs * inputs
    scaled = snt.scale_gradient(squared, scale)
    grad = tf.gradients([scaled], [inputs])[0]

    if scale == 0.0:
        # A zero scale is expected to become a StopGradient with no gradient.
        self.assertEqual(scaled.op.type, "StopGradient")
        self.assertIsNone(grad)
        return

    expected_type = "Identity" if scale == 1.0 else "ScaleGradient_float32"
    self.assertEqual(scaled.op.type, expected_type)

    with self.test_session() as sess:
        grad_val, out_val = sess.run([grad, scaled], feed_dict={inputs: [x_]})
    # d/dx (x * x) = 2x, multiplied by the gradient scale.
    self.assertAlmostEqual(grad_val[0], 2 * scale * x_, places=6)
    self.assertAlmostEqual(out_val[0], x_ ** 2, places=6)
def testShape(self):
    """Tests that scale_gradient keeps the partially-known static shape."""
    inputs = tf.placeholder(tf.float32, [None, 10, 13])
    outputs = snt.scale_gradient(inputs, 0.1)
    static_dims = outputs.get_shape().as_list()
    self.assertEqual(tuple(static_dims), (None, 10, 13))