def test_setitem(self):
    """Gradients must flow correctly through in-place ``__setitem__``.

    Two indexing forms are exercised: a single integer index and a tuple
    index. In each case the element(s) overwritten by the assignment get a
    zero gradient with respect to the array's original contents. Every other
    element, and every assigned value, gets a gradient of one from
    ``math.sum``.
    """

    def to_py(tensor):
        # Re-wrap a raw gradient tensor so it can be compared to literals.
        return array_creation.array(tensor).tolist()

    # --- Case 1: assignment through a single integer index. ---
    vec = array_creation.array([1., 2., 3.])
    lhs = array_creation.array(5.)
    rhs = array_creation.array(10.)
    # The underlying tensors are captured before the assignment, so gradients
    # are taken with respect to the original contents of `vec`.
    sources = [arr.data for arr in [vec, lhs, rhs]]
    with tf.GradientTape() as tape:
        tape.watch(sources)
        vec[1] = lhs + rhs
        objective = math.sum(vec)
    grad_vec, grad_lhs, grad_rhs = tape.gradient(objective.data, sources)
    self.assertSequenceEqual(to_py(grad_vec), [1., 0., 1.])
    self.assertEqual(to_py(grad_lhs), 1.)
    self.assertEqual(to_py(grad_rhs), 1.)

    # --- Case 2: assignment through a tuple index into a 2x2x2 array. ---
    cube = array_creation.array([[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]]])
    row = array_creation.array([10., 11.])
    sources = [arr.data for arr in [cube, row]]
    with tf.GradientTape() as tape:
        tape.watch(sources)
        cube[(1, 0)] = row
        objective = math.sum(cube)
    grad_cube, grad_row = tape.gradient(objective.data, sources)
    self.assertSequenceEqual(
        to_py(grad_cube), [[[1., 1.], [1., 1.]], [[0., 0.], [1., 1.]]])
    self.assertEqual(to_py(grad_row), [1., 1.])
def f(a, b):
    """Sum ``sqrt(exp(a)) + b`` down to a single value.

    Args:
      a: Input passed through ``math.exp`` and then ``math.sqrt``.
      b: Value added to the transformed ``a`` before summation.

    Returns:
      The result of ``math.sum`` applied to the combined expression.
    """
    transformed = math.sqrt(math.exp(a))
    return math.sum(transformed + b)