Ejemplo n.º 1
0
    def test_torch(self):
        """Check that scatter element addition keeps a torch tensor differentiable.

        ``scalar ** 2`` is added to element ``[1, 2]`` of a 2x3 tensor of ones.
        The test verifies the result type and values, then backpropagates from
        the updated element and checks the gradients with respect to both the
        tensor and the scalar.
        """
        row, col = 1, 2
        expected_res = onp.array([[1.0, 1.0, 1.0], [1.0, 1.0, 1.3136]])
        expected_tensor_grad = onp.array([[0, 0, 0], [0, 0, 1.0]])

        tensor = torch.tensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], requires_grad=True)
        scalar = torch.tensor(0.56, requires_grad=True)

        res = fn.scatter_element_add(tensor, [row, col], scalar ** 2)

        # Forward pass: only the targeted element changes (1 + 0.56 ** 2 = 1.3136).
        assert isinstance(res, torch.Tensor)
        assert fn.allclose(res.detach(), expected_res)

        # Backward pass from the updated element alone.
        updated_element = res[row, col]
        updated_element.backward()

        assert fn.allclose(tensor.grad, expected_tensor_grad)
        # d(scalar ** 2) / d(scalar) == 2 * scalar
        assert fn.allclose(scalar.grad, 2 * scalar)
Ejemplo n.º 2
0
    def test_tensorflow(self):
        """Check that scatter element addition keeps a TF tensor differentiable.

        ``scalar ** 2`` is added to element ``[1, 2]`` of a 2x3 variable of ones
        while a gradient tape is recording. The test verifies the result type
        and values, then checks the gradients of the updated element with
        respect to both the variable and the scalar.
        """
        row, col = 1, 2
        expected_res = onp.array([[1.0, 1.0, 1.0], [1.0, 1.0, 1.3136]])
        expected_tensor_grad = onp.array([[0, 0, 0], [0, 0, 1.0]])

        tensor = tf.Variable([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])
        scalar = tf.Variable(0.56)

        # The element lookup must happen inside the tape so it is recorded.
        with tf.GradientTape() as tape:
            res = fn.scatter_element_add(tensor, [row, col], scalar ** 2)
            updated_element = res[row, col]

        # Forward pass: only the targeted element changes (1 + 0.56 ** 2 = 1.3136).
        assert isinstance(res, tf.Tensor)
        assert fn.allclose(res, expected_res)

        tensor_grad, scalar_grad = tape.gradient(updated_element, [tensor, scalar])

        assert fn.allclose(tensor_grad, expected_tensor_grad)
        # d(scalar ** 2) / d(scalar) == 2 * scalar
        assert fn.allclose(scalar_grad, 2 * scalar)
Ejemplo n.º 3
0
 def cost(weights):
     """Scatter-add ``weights[1] ** 2`` into ``weights[0]`` at index ``[1, 2]``.

     Returns whatever ``fn.scatter_element_add`` produces for those arguments.
     """
     target = weights[0]
     increment = weights[1] ** 2
     return fn.scatter_element_add(target, [1, 2], increment)