示例#1
0
    def test_mean_with_grad(self):
        """Check the gradient of ``mean`` when ``backward`` gets an explicit gradient.

        The mean of a three-element tensor is back-propagated with
        ``Tensor(3.)`` as the incoming gradient, and the gradient stored on
        the input is expected to equal ``[1, 1, 1]``.
        """
        inputs = Tensor([1, 2, 3], requires_grad=True)
        averaged = inputs.mean()

        # Seed the backward pass with an explicit upstream gradient.
        upstream_grad = Tensor(3.)
        averaged.backward(upstream_grad)

        expected_grad = [1, 1, 1]
        assert inputs.grad.data.tolist() == expected_grad
示例#2
0
    def test_simple_mean(self):
        """Check the value and the default gradient of ``mean``.

        Takes the mean of ``[3, 6, 9]``, calls ``backward`` with no
        arguments, and expects a result of 6 and an input gradient that
        rounds to 0.33 for each of the three elements.
        """
        values = Tensor([3, 6, 9], requires_grad=True)

        average = values.mean()
        average.backward()

        mean_value = average.data.tolist()
        assert mean_value == 6

        # Round to two decimals so the comparison uses a short literal.
        rounded_grad = np.round(values.grad.data, 2)
        assert rounded_grad.tolist() == [0.33, 0.33, 0.33]
示例#3
0
    def test_mean_with_grad_torch(self):
        """Compare the ``mean`` gradient with torch for an explicit upstream gradient.

        The same data is run through this project's ``Tensor`` and through a
        reference ``tensor`` (presumably ``torch.tensor`` -- confirm against
        the file's imports). Both means are back-propagated with a scalar
        gradient of 3, and the gradients accumulated on the two inputs must
        be equal.
        """
        tensor1 = Tensor([1, 2, 3], requires_grad=True)
        tensor2 = tensor1.mean()

        tensor_torch1 = tensor([1, 2, 3], dtype=float, requires_grad=True)
        tensor_torch2 = tensor_torch1.mean()

        tensor2.backward(Tensor(3.))
        tensor_torch2.backward(tensor(3.))

        # Compare against the reference gradient. Comparing tensor1.grad with
        # itself would pass no matter what backward() computed.
        assert tensor1.grad.data.tolist() == tensor_torch1.grad.data.tolist()
示例#4
0
    def test_simple_mean_torch(self):
        """Compare ``mean`` and its default gradient with the torch reference.

        The mean of ``[3, 6, 9]`` is computed with this project's ``Tensor``
        and with the reference ``tensor``. ``backward`` is called on both
        without arguments, then the forward results and the input gradients
        are required to match.
        """
        ours = Tensor([3, 6, 9], requires_grad=True)
        ours_mean = ours.mean()

        reference = tensor([3, 6, 9], dtype=float, requires_grad=True)
        reference_mean = reference.mean()

        ours_mean.backward()
        reference_mean.backward()

        # Forward value must agree with the reference.
        forward_ours = ours_mean.data.tolist()
        forward_reference = reference_mean.data.tolist()
        assert forward_ours == forward_reference

        # Gradient on the input must agree with the reference.
        grad_ours = ours.grad.data.tolist()
        grad_reference = reference.grad.data.tolist()
        assert grad_ours == grad_reference