def testCustomGrad(self):
  """Gradients of a wrapped fn should be exactly what grad_fn emits."""

  def fn(a, b, c):
    return core_layers.dense(a, 10, use_bias=False) + math_ops.matmul(b, c)

  def grad_fn(inputs, trainable_variables, unused_outputs,
              unused_grad_outputs):
    # Emit constant, position-dependent gradients so each one is
    # identifiable: input i gets (i + 1); variables continue the count
    # after the inputs.
    grad_inputs = [
        array_ops.ones_like(x) * float(n + 1) for n, x in enumerate(inputs)
    ]
    offset = len(inputs)
    grad_vars = [
        array_ops.ones_like(v) * float(n + offset + 1)
        for n, v in enumerate(trainable_variables)
    ]
    return grad_inputs, grad_vars

  a = random_ops.random_uniform([11, 6])
  b = random_ops.random_uniform([11, 7])
  c = random_ops.random_uniform([7, 10])
  w = random_ops.random_uniform([6, 10])

  out = rev_block_lib._fn_with_custom_grad(grad_fn)(fn)(a, b, c)
  loss = math_ops.reduce_mean(out)
  # The dense layer's kernel is the only trainable variable in the graph.
  grads = gradients_impl.gradients(
      loss, [a, b, c, variables.trainable_variables()[0]])
  # Expected gradients mirror grad_fn's scheme: (i + 1) for each tensor,
  # with `w` standing in for the kernel's shape.
  expected_grads = [
      array_ops.ones_like(t) * float(i + 1)
      for i, t in enumerate([a, b, c, w])
  ]

  with self.test_session() as sess:
    sess.run(variables.global_variables_initializer())
    actual_vals, expected_vals = sess.run([grads, expected_grads])
    for actual, expected in zip(actual_vals, expected_vals):
      self.assertAllClose(actual, expected)
def testCustomGrad(self):
  """Checks that _fn_with_custom_grad routes gradients through grad_fn."""

  def fn(a, b, c):
    dense_out = core_layers.dense(a, 10, use_bias=False)
    return dense_out + math_ops.matmul(b, c)

  def grad_fn(inputs, trainable_variables, unused_outputs,
              unused_grad_outputs):
    # Constant sentinel gradients: the i-th input receives (i + 1), and
    # each trainable variable continues counting past the inputs, so any
    # mix-up in routing would be visible in the values below.
    grad_inputs = [
        array_ops.ones_like(t) * (idx + 1.) for idx, t in enumerate(inputs)
    ]
    grad_vars = [
        array_ops.ones_like(t) * (idx + len(inputs) + 1.)
        for idx, t in enumerate(trainable_variables)
    ]
    return grad_inputs, grad_vars

  a = random_ops.random_uniform([11, 6])
  b = random_ops.random_uniform([11, 7])
  c = random_ops.random_uniform([7, 10])
  # Same shape as the dense kernel; used only to build expected gradients.
  w = random_ops.random_uniform([6, 10])

  wrapped = rev_block_lib._fn_with_custom_grad(grad_fn)(fn)
  loss = math_ops.reduce_mean(wrapped(a, b, c))
  grads = gradients_impl.gradients(
      loss, [a, b, c, variables.trainable_variables()[0]])
  expected_grads = [
      array_ops.ones_like(t) * (idx + 1.)
      for idx, t in enumerate([a, b, c, w])
  ]

  with self.test_session() as sess:
    sess.run(variables.global_variables_initializer())
    g_val, eg_val = sess.run([grads, expected_grads])
    for got, want in zip(g_val, eg_val):
      self.assertAllClose(got, want)
def testCorrectness(self):
  """A custom-grad fn computing true gradients must match the plain fn."""
  w = random_ops.random_uniform([6, 10])

  def fn(a, b, c):
    # Pin the dense kernel's initial value to `w` so the plain and the
    # wrapped invocations start from identical weights.
    return core_layers.dense(
        a, 10, use_bias=False,
        kernel_initializer=lambda shape, dtype, partition_info: w
    ) + math_ops.matmul(b, c)

  def grad_fn(inputs, trainable_variables, outputs, grad_outputs):
    # Recompute the genuine gradients, so the wrapped fn should be
    # numerically indistinguishable from the unwrapped one.
    output = outputs[0]
    grad_output = grad_outputs[0]
    grad_inputs = gradients_impl.gradients(
        output, inputs, grad_ys=grad_output)
    grad_vars = gradients_impl.gradients(
        output, trainable_variables, grad_ys=grad_output)
    return grad_inputs, grad_vars

  custom_fn = rev_block_lib._fn_with_custom_grad(grad_fn)(fn)

  a = random_ops.random_uniform([11, 6])
  b = random_ops.random_uniform([11, 7])
  c = random_ops.random_uniform([7, 10])

  out = fn(a, b, c)
  custom_out = custom_fn(a, b, c)
  self.assertEqual(out.get_shape().as_list(),
                   custom_out.get_shape().as_list())

  loss = math_ops.reduce_mean(out)
  custom_loss = math_ops.reduce_mean(custom_out)
  # fn created trainable variable 0; custom_fn created variable 1.
  grads = gradients_impl.gradients(
      loss, [a, b, c, variables.trainable_variables()[0]])
  custom_grads = gradients_impl.gradients(
      custom_loss, [a, b, c, variables.trainable_variables()[1]])

  with self.test_session() as sess:
    sess.run(variables.global_variables_initializer())
    out_val, custom_out_val, grads_val, custom_grads_val = sess.run(
        [out, custom_out, grads, custom_grads])
    self.assertAllClose(out_val, custom_out_val)
    for plain_grad, custom_grad in zip(grads_val, custom_grads_val):
      self.assertAllClose(plain_grad, custom_grad)
def testCorrectness(self):
  """Wrapping fn with its own true gradients must be a no-op numerically."""
  w = random_ops.random_uniform([6, 10])

  def fn(a, b, c):
    # The kernel initializer returns `w`, so both the plain and the
    # wrapped call build a dense layer with the same starting weights.
    kernel_init = lambda shape, dtype, partition_info: w
    dense_out = core_layers.dense(
        a, 10, use_bias=False, kernel_initializer=kernel_init)
    return dense_out + math_ops.matmul(b, c)

  def grad_fn(inputs, trainable_variables, outputs, grad_outputs):
    # grad_fn delegates straight back to TF's gradient computation, so
    # the custom path should reproduce the default gradients exactly.
    y = outputs[0]
    dy = grad_outputs[0]
    grad_inputs = gradients_impl.gradients(y, inputs, grad_ys=dy)
    grad_vars = gradients_impl.gradients(
        y, trainable_variables, grad_ys=dy)
    return grad_inputs, grad_vars

  custom_fn = rev_block_lib._fn_with_custom_grad(grad_fn)(fn)

  a = random_ops.random_uniform([11, 6])
  b = random_ops.random_uniform([11, 7])
  c = random_ops.random_uniform([7, 10])

  out = fn(a, b, c)
  custom_out = custom_fn(a, b, c)
  self.assertEqual(out.get_shape().as_list(),
                   custom_out.get_shape().as_list())

  loss = math_ops.reduce_mean(out)
  custom_loss = math_ops.reduce_mean(custom_out)
  # Variable 0 belongs to the plain call, variable 1 to the wrapped one.
  plain_targets = [a, b, c] + [variables.trainable_variables()[0]]
  custom_targets = [a, b, c] + [variables.trainable_variables()[1]]
  grads = gradients_impl.gradients(loss, plain_targets)
  custom_grads = gradients_impl.gradients(custom_loss, custom_targets)

  with self.test_session() as sess:
    sess.run(variables.global_variables_initializer())
    out_val, custom_out_val, grads_val, custom_grads_val = sess.run(
        [out, custom_out, grads, custom_grads])
    self.assertAllClose(out_val, custom_out_val)
    for g_plain, g_custom in zip(grads_val, custom_grads_val):
      self.assertAllClose(g_plain, g_custom)