def test_matmul_batch_mat(self):
        """Batched matmul of a lazy 3-factor Kronecker product matches the dense one.

        The factors ``a``, ``b`` and ``c`` (defined outside this method) are
        each tiled into a leading batch dimension of size 3.  The result of
        ``KroneckerProductLazyVariable.matmul`` and the gradients of every
        input are compared with those of ``kron(kron(a, b), c).matmul(...)``.
        Assumes the full product has 24 columns -- TODO confirm against the
        shapes of ``a``, ``b`` and ``c``.
        """
        def batched_leaf(t):
            # A fresh leaf tensor holding ``t`` repeated along a new batch dim.
            return Variable(t.repeat(3, 1, 1), requires_grad=True)

        # Lazy path.
        lazy_leaves = [batched_leaf(t) for t in (a, b, c)]
        rhs = Variable(torch.randn(3, 24, 5), requires_grad=True)
        lazy_product = KroneckerProductLazyVariable(
            *[NonLazyVariable(leaf) for leaf in lazy_leaves]
        )
        res = lazy_product.matmul(rhs)

        # Dense reference path built from independent leaves and a cloned rhs.
        dense_leaves = [batched_leaf(t) for t in (a, b, c)]
        rhs_dense = Variable(rhs.data.clone(), requires_grad=True)
        first, second, third = dense_leaves
        actual = kron(kron(first, second), third).matmul(rhs_dense)
        self.assertTrue(approx_equal(res.data, actual.data))

        actual.sum().backward()
        res.sum().backward()
        # Gradients of a, b, c, then of the right-hand side, must agree.
        for dense, lazy in zip(dense_leaves + [rhs_dense], lazy_leaves + [rhs]):
            self.assertTrue(approx_equal(dense.grad.data, lazy.grad.data))
    def test_matmul_vec(self):
        """Matrix-vector product of a lazy 3-factor Kronecker product matches the dense one.

        Uses the factors ``a``, ``b`` and ``c`` defined outside this method and
        a random vector of length 24 (assumed to equal the product's column
        count -- TODO confirm).  Forward values and the gradients of every
        input are compared against ``kron(kron(a, b), c).matmul(vec)``.
        """
        # Lazy path.
        lazy_leaves = [Variable(t, requires_grad=True) for t in (a, b, c)]
        vec = Variable(torch.randn(24), requires_grad=True)
        res = KroneckerProductLazyVariable(
            *[NonLazyVariable(leaf) for leaf in lazy_leaves]
        ).matmul(vec)

        # Dense reference path with its own leaves and a cloned vector.
        dense_leaves = [Variable(t, requires_grad=True) for t in (a, b, c)]
        vec_dense = Variable(vec.data.clone(), requires_grad=True)
        left, middle, right = dense_leaves
        actual = kron(kron(left, middle), right).matmul(vec_dense)

        self.assertTrue(approx_equal(res.data, actual.data))

        actual.sum().backward()
        res.sum().backward()
        pairs = zip(dense_leaves + [vec_dense], lazy_leaves + [vec])
        for dense, lazy in pairs:
            self.assertTrue(approx_equal(dense.grad.data, lazy.grad.data))
Example #3
0
    def test_matmul_mat_random_rectangular(self):
        """Batched matmul with random rectangular factors matches the dense product.

        The factors have shapes (4, 2, 3), (4, 5, 2) and (4, 6, 4), so the
        dense Kronecker product has 3 * 2 * 4 = 24 columns, matching the
        (4, 24, 2) right-hand side (assuming ``kron`` is the usual batched
        Kronecker product).  Forward values and gradients of all inputs are
        compared between the lazy and the explicit computation.
        """
        a = torch.randn(4, 2, 3)
        b = torch.randn(4, 5, 2)
        c = torch.randn(4, 6, 4)
        rhs = torch.randn(4, 3 * 2 * 4, 2)
        # Independent copies for the dense reference path.  ``clone()`` is the
        # recommended way to copy a tensor (``torch.tensor(a)`` emits a
        # copy-construct warning).  The copies are taken *before*
        # ``requires_grad`` is enabled so that they are leaves and receive
        # their own ``.grad``.
        a_copy = a.clone()
        b_copy = b.clone()
        c_copy = c.clone()
        rhs_copy = rhs.clone()

        for leaf in (a, b, c, a_copy, b_copy, c_copy, rhs, rhs_copy):
            leaf.requires_grad_(True)

        actual = kron(kron(a_copy, b_copy), c_copy).matmul(rhs_copy)
        kp_lazy_var = KroneckerProductLazyVariable(NonLazyVariable(a),
                                                   NonLazyVariable(b),
                                                   NonLazyVariable(c))
        res = kp_lazy_var.matmul(rhs)

        self.assertTrue(approx_equal(res.data, actual.data))

        actual.sum().backward()
        res.sum().backward()
        self.assertTrue(approx_equal(a_copy.grad.data, a.grad.data))
        self.assertTrue(approx_equal(b_copy.grad.data, b.grad.data))
        self.assertTrue(approx_equal(c_copy.grad.data, c.grad.data))
        self.assertTrue(approx_equal(rhs_copy.grad.data, rhs.grad.data))
def test_matmul_mat():
    """Matrix-matrix product of a lazy 3-factor Kronecker product matches the dense one.

    Uses the factors ``a``, ``b`` and ``c`` defined outside this function and
    a random (24, 5) right-hand side (24 is assumed to be the product's
    column count -- TODO confirm).  Checks forward values and the gradients
    of all inputs against ``kron(kron(a, b), c).matmul(mat)``.
    """
    factors = (a, b, c)

    # Lazy path.
    lazy_leaves = [Variable(f, requires_grad=True) for f in factors]
    mat = Variable(torch.randn(24, 5), requires_grad=True)
    lazy_product = KroneckerProductLazyVariable(
        *[NonLazyVariable(leaf) for leaf in lazy_leaves]
    )
    res = lazy_product.matmul(mat)

    # Dense reference path with separate leaves and a cloned rhs.
    dense_leaves = [Variable(f, requires_grad=True) for f in factors]
    mat_dense = Variable(mat.data.clone(), requires_grad=True)
    p, q, r = dense_leaves
    actual = kron(kron(p, q), r).matmul(mat_dense)
    assert approx_equal(res.data, actual.data)

    actual.sum().backward()
    res.sum().backward()
    for dense_leaf, lazy_leaf in zip(dense_leaves + [mat_dense],
                                     lazy_leaves + [mat]):
        assert approx_equal(dense_leaf.grad.data, lazy_leaf.grad.data)
Example #5
0
    def test_matmul_vec_random_rectangular(self):
        """Batched matrix-vector product with rectangular factors matches the dense one.

        The factors have shapes (4, 2, 3), (4, 5, 2) and (4, 6, 4); the
        right-hand side is a (4, 24, 1) column (24 = 3 * 2 * 4) normalized by
        its overall norm.  Forward values and the gradients of all inputs are
        compared between the lazy and the explicit Kronecker product.
        """
        ax = torch.randn(4, 2, 3)
        bx = torch.randn(4, 5, 2)
        cx = torch.randn(4, 6, 4)
        rhsx = torch.randn(4, 3 * 2 * 4, 1)
        rhsx = rhsx / torch.norm(rhsx)
        # Independent copies for the dense reference path.  Previously
        # ``ax_copy`` was built with the deprecated ``Variable`` wrapper, which
        # shares storage with ``ax`` instead of copying it like its siblings.
        # Cloning before enabling ``requires_grad`` keeps every copy a leaf.
        ax_copy = ax.clone()
        bx_copy = bx.clone()
        cx_copy = cx.clone()
        rhsx_copy = rhsx.clone()

        for leaf in (ax, bx, cx, ax_copy, bx_copy, cx_copy, rhsx, rhsx_copy):
            leaf.requires_grad_(True)

        kp_lazy_var = KroneckerProductLazyVariable(NonLazyVariable(ax),
                                                   NonLazyVariable(bx),
                                                   NonLazyVariable(cx))
        res = kp_lazy_var.matmul(rhsx)

        actual_mat = kron(kron(ax_copy, bx_copy), cx_copy)
        actual = actual_mat.matmul(rhsx_copy)

        self.assertTrue(approx_equal(res.data, actual.data))

        actual.sum().backward()
        res.sum().backward()
        self.assertTrue(approx_equal(ax_copy.grad.data, ax.grad.data))
        self.assertTrue(approx_equal(bx_copy.grad.data, bx.grad.data))
        self.assertTrue(approx_equal(cx_copy.grad.data, cx.grad.data))
        self.assertTrue(approx_equal(rhsx_copy.grad.data, rhsx.grad.data))