def test_batch_matmul():
    """Batched BlockDiagonalLazyVariable.matmul must agree with a dense reference.

    Two batches of four 4x4 blocks each (``n_blocks=4``) are compared against an
    explicitly assembled (2, 16, 16) block-diagonal tensor. The check covers the
    forward product and the gradients w.r.t. the right-hand side and the blocks.
    """
    # NOTE(review): relies on a module-level ``blocks`` tensor; the indexing below
    # assumes it holds (at least) 8 blocks of size 4x4 -- confirm against its definition.
    rhs = torch.randn(2, 4 * 4, 4)
    # Two independent leaves over the same data: one feeds the lazy path and the
    # other the dense reference, so that their gradients can be compared.
    rhs_var, rhs_var_copy = (Variable(rhs, requires_grad=True),
                             Variable(rhs, requires_grad=True))
    block_var, block_var_copy = (Variable(blocks, requires_grad=True),
                                 Variable(blocks, requires_grad=True))

    # Flat block index k maps to (batch, position-within-batch) = divmod(k, 4).
    dense = Variable(torch.zeros(2, 16, 16))
    for flat_idx in range(8):
        batch, pos = divmod(flat_idx, 4)
        span = slice(pos * 4, (pos + 1) * 4)
        dense[batch, span, span] = block_var_copy[flat_idx]

    lazy = BlockDiagonalLazyVariable(NonLazyVariable(block_var), n_blocks=4)
    res = lazy.matmul(rhs_var)
    actual = dense.matmul(rhs_var_copy)
    assert approx_equal(res.data, actual.data)

    # Backpropagate through the dense reference first, then through the lazy result.
    actual.sum().backward()
    res.sum().backward()

    assert approx_equal(rhs_var.grad.data, rhs_var_copy.grad.data)
    assert approx_equal(block_var.grad.data, block_var_copy.grad.data)
Example #2
0
    def test_matmul(self):
        """Non-batched BlockDiagonalLazyVariable.matmul must agree with a dense reference.

        Eight 4x4 blocks are laid out along the diagonal of a 32x32 tensor. The
        lazy and dense products, and their gradients w.r.t. the right-hand side
        and the blocks, must match.
        """
        # NOTE(review): uses the module-level ``blocks`` tensor, assumed here to
        # hold 8 blocks of size 4x4 -- confirm against its definition.
        rhs = torch.randn(4 * 8, 4)
        rhs_var, rhs_var_copy = (Variable(rhs, requires_grad=True),
                                 Variable(rhs, requires_grad=True))
        block_var, block_var_copy = (Variable(blocks, requires_grad=True),
                                     Variable(blocks, requires_grad=True))

        dense = Variable(torch.zeros(32, 32))
        for k in range(8):
            lo, hi = k * 4, (k + 1) * 4
            dense[lo:hi, lo:hi] = block_var_copy[k]

        lazy = BlockDiagonalLazyVariable(NonLazyVariable(block_var))
        res = lazy.matmul(rhs_var)
        actual = dense.matmul(rhs_var_copy)
        self.assertTrue(approx_equal(res.data, actual.data))

        actual.sum().backward()
        res.sum().backward()

        # Compare the lazy-path gradients against the dense-path ones: rhs first, then blocks.
        for lazy_grad, dense_grad in ((rhs_var.grad, rhs_var_copy.grad),
                                      (block_var.grad, block_var_copy.grad)):
            self.assertTrue(approx_equal(lazy_grad.data, dense_grad.data))
Example #3
0
    def test_getitem_batch(self):
        """Indexing a batched BlockDiagonalLazyVariable must agree with a dense reference.

        Three cases are covered:
        - a single batch index;
        - a batch index plus a row slice;
        - a batch slice combined with a row slice and an integer column.
        """
        block_var = Variable(blocks, requires_grad=True)
        dense = Variable(torch.zeros(2, 16, 16))
        for flat_idx in range(8):
            batch, pos = divmod(flat_idx, 4)
            lo, hi = pos * 4, (pos + 1) * 4
            dense[batch, lo:hi, lo:hi] = block_var[flat_idx]

        def make_lazy():
            # Each case indexes a freshly constructed lazy variable.
            return BlockDiagonalLazyVariable(NonLazyVariable(block_var), n_blocks=4)

        res = make_lazy()[0].evaluate()
        self.assertTrue(approx_equal(dense[0].data, res.data))

        res = make_lazy()[0, :5].evaluate()
        self.assertTrue(approx_equal(dense[0, :5].data, res.data))

        # This case compares the indexing result directly, without calling .evaluate().
        res = make_lazy()[1:, :5, 2]
        self.assertTrue(approx_equal(dense[1:, :5, 2].data, res.data))
Example #4
0
    def test_getitem(self):
        """Indexing a BlockDiagonalLazyVariable with ``[:5, 2]`` must match the dense matrix."""
        block_var = Variable(blocks, requires_grad=True)
        dense = Variable(torch.zeros(32, 32))
        for k in range(8):
            span = slice(k * 4, (k + 1) * 4)
            dense[span, span] = block_var[k]

        lazy = BlockDiagonalLazyVariable(NonLazyVariable(block_var))
        res = lazy[:5, 2]
        expected = dense[:5, 2]
        self.assertTrue(approx_equal(expected.data, res.data))
def test_diag():
    """BlockDiagonalLazyVariable.diag must equal the diagonal of the dense 32x32 matrix."""
    block_var = Variable(blocks, requires_grad=True)
    dense = Variable(torch.zeros(32, 32))
    for k in range(8):
        lo, hi = 4 * k, 4 * (k + 1)
        dense[lo:hi, lo:hi] = block_var[k]

    lazy = BlockDiagonalLazyVariable(NonLazyVariable(block_var))
    res = lazy.diag()
    expected = dense.diag()
    assert approx_equal(expected.data, res.data)
Example #6
0
    def test_batch_diag(self):
        """Batched BlockDiagonalLazyVariable.diag must match the per-batch dense diagonals."""
        block_var = Variable(blocks, requires_grad=True)
        dense = Variable(torch.zeros(2, 16, 16))
        for flat_idx in range(8):
            batch, pos = divmod(flat_idx, 4)
            span = slice(pos * 4, (pos + 1) * 4)
            dense[batch, span, span] = block_var[flat_idx]

        res = BlockDiagonalLazyVariable(NonLazyVariable(block_var), n_blocks=4).diag()
        # Concatenate each batch's 16-element diagonal into a (2, 16) reference.
        actual = torch.cat([dense[b].diag().unsqueeze(0) for b in range(2)])
        self.assertTrue(approx_equal(actual.data, res.data))