Example #1
0
    def test_function_factory(self):
        """Check DiagLazyVariable.matmul and its gradients against the dense form.

        Two right-hand sides are exercised: a 1-d ``torch.Tensor([3, 4, 5])``
        and a 2-d ``torch.eye(3)``. For each one:

        * forward: ``DiagLazyVariable(...).matmul(rhs)`` must match
          ``torch.matmul`` on the matrix returned by ``.evaluate()``
          (norm of the difference < 1e-4);
        * backward: after ``.sum().backward()`` on both results, the gradients
          on the two ``diag`` leaves must agree (norm of the difference < 1e-3).

        ``diag`` is a module-level tensor defined outside this view
        (presumably a length-3 1-d tensor -- TODO confirm).
        """

        def check(test_mat, case):
            # Run one forward/backward comparison for the right-hand side
            # ``test_mat``; ``case`` labels assertion failures.
            # Two independent leaf Variables over the same data, so the lazy
            # path and the dense path each accumulate their own ``.grad``.
            diag_var1 = Variable(diag, requires_grad=True)
            diag_var2 = Variable(diag, requires_grad=True)

            diag_lv = DiagLazyVariable(diag_var1)
            diag_ev = DiagLazyVariable(diag_var2).evaluate()

            # Forward
            res = diag_lv.matmul(Variable(test_mat))
            actual = torch.matmul(diag_ev, Variable(test_mat))
            self.assertLess(
                torch.norm(res.data - actual.data),
                1e-4,
                msg="%s forward mismatch" % case,
            )

            # Backward
            res.sum().backward()
            actual.sum().backward()
            self.assertLess(
                torch.norm(diag_var1.grad.data - diag_var2.grad.data),
                1e-3,
                msg="%s backward mismatch" % case,
            )

        check(torch.Tensor([3, 4, 5]), "1d")
        check(torch.eye(3), "2d")
def test_batch_function_factory():
    """Batched DiagLazyVariable.matmul must agree with its dense evaluation.

    ``diag`` is repeated 5 times along a new leading dimension and multiplied
    by a batch of 5 identity matrices. The lazy and the evaluated-dense
    products must match in the forward pass (norm of the difference < 1e-4).
    After ``.sum().backward()`` on both, the gradients on the two leaves must
    match as well (norm of the difference < 1e-3).

    ``diag`` is a module-level tensor defined outside this view
    (presumably 1-d, making the repeated leaves shape (5, len(diag)) --
    TODO confirm).
    """
    # 2d, batched: separate leaves so each path accumulates its own .grad
    lazy_leaf = Variable(diag.repeat(5, 1), requires_grad=True)
    dense_leaf = Variable(diag.repeat(5, 1), requires_grad=True)
    eye_batch = torch.eye(3).repeat(5, 1, 1)

    # Forward: lazy matmul vs. matmul on the explicitly evaluated matrix
    lazy_out = DiagLazyVariable(lazy_leaf).matmul(Variable(eye_batch))
    dense_out = torch.matmul(
        DiagLazyVariable(dense_leaf).evaluate(), Variable(eye_batch)
    )
    forward_err = torch.norm(lazy_out.data - dense_out.data)
    assert forward_err < 1e-4

    # Backward: reduce each output to a scalar and compare leaf gradients
    for out in (lazy_out, dense_out):
        out.sum().backward()
    grad_err = torch.norm(lazy_leaf.grad.data - dense_leaf.grad.data)
    assert grad_err < 1e-3