def test_torch(self):
    """Scatter-adding into a torch tensor must keep both inputs differentiable.

    ``shift ** 2`` is added to ``base`` at position ``[1, 2]``. The test checks:
    the forward result, that it is still a ``torch.Tensor``, and the gradients of
    that single updated entry with respect to ``base`` and ``shift``.
    """
    base = torch.tensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], requires_grad=True)
    shift = torch.tensor(0.56, requires_grad=True)

    result = fn.scatter_element_add(base, [1, 2], shift ** 2)
    assert isinstance(result, torch.Tensor)

    # 1.0 + 0.56 ** 2 == 1.3136 appears only at row 1, column 2.
    expected_value = onp.array([[1.0, 1.0, 1.0], [1.0, 1.0, 1.3136]])
    assert fn.allclose(result.detach(), expected_value)

    # Backpropagate from the single updated entry only.
    target = result[1, 2]
    target.backward()

    # d(target)/d(base) is one-hot at the scattered position ...
    expected_base_grad = onp.array([[0, 0, 0], [0, 0, 1.0]])
    assert fn.allclose(base.grad, expected_base_grad)
    # ... and d(target)/d(shift) = d(shift ** 2)/d(shift) = 2 * shift.
    assert fn.allclose(shift.grad, 2 * shift)
def test_tensorflow(self):
    """Scatter-adding into a TF variable must keep both inputs differentiable.

    ``shift ** 2`` is added to ``base`` at position ``[1, 2]`` under a
    ``tf.GradientTape``. The test checks the forward result and the gradients of
    that single updated entry with respect to ``base`` and ``shift``.
    """
    base = tf.Variable([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])
    shift = tf.Variable(0.56)

    # Trainable tf.Variables are watched by the tape without an explicit watch().
    with tf.GradientTape() as tape:
        result = fn.scatter_element_add(base, [1, 2], shift ** 2)
        target = result[1, 2]

    assert isinstance(result, tf.Tensor)
    # 1.0 + 0.56 ** 2 == 1.3136 appears only at row 1, column 2.
    expected_value = onp.array([[1.0, 1.0, 1.0], [1.0, 1.0, 1.3136]])
    assert fn.allclose(result, expected_value)

    base_grad, shift_grad = tape.gradient(target, [base, shift])

    # One-hot gradient at the scattered position ...
    expected_base_grad = onp.array([[0, 0, 0], [0, 0, 1.0]])
    assert fn.allclose(base_grad, expected_base_grad)
    # ... and d(shift ** 2)/d(shift) = 2 * shift.
    assert fn.allclose(shift_grad, 2 * shift)
def cost(weights):
    """Return ``weights[0]`` with ``weights[1] ** 2`` scatter-added at position ``[1, 2]``.

    Only ``weights[0]`` and ``weights[1]`` are read.
    """
    base = weights[0]
    shift = weights[1]
    return fn.scatter_element_add(base, [1, 2], shift ** 2)