Exemplo n.º 1
0
def test_mul_adding_constant_mul():
    """Scaling a three-factor MulLazyVariable must agree with the dense result.

    Two scalings are checked: multiplication by a one-element ``Variable``
    and multiplication by the plain Python float ``2.5``. In both cases the
    evaluated lazy variable is compared against ``prod`` of the explicit
    ``M.matmul(M^T)`` matrices times the same scale, and the largest
    element-wise relative error has to stay below ``0.01``.
    """
    mat1 = make_random_mat(20, rank=4, batch_size=5)
    mat2 = make_random_mat(20, rank=4, batch_size=5)
    mat3 = make_random_mat(20, rank=4, batch_size=5)
    const = Variable(torch.ones(1), requires_grad=True)

    # Independent leaf Variables sharing the same data, used for the dense
    # reference computation.
    mat1_copy = Variable(mat1.data, requires_grad=True)
    mat2_copy = Variable(mat2.data, requires_grad=True)
    mat3_copy = Variable(mat3.data, requires_grad=True)
    const_copy = Variable(const.data, requires_grad=True)

    def build_lazy():
        # Fresh lazy product of the three root-decomposed factors.
        return MulLazyVariable(
            RootLazyVariable(mat1),
            RootLazyVariable(mat2),
            RootLazyVariable(mat3),
        )

    def build_dense():
        # Explicit reference: prod over the M.matmul(M^T) matrices.
        return prod([
            root.matmul(root.transpose(-1, -2))
            for root in (mat1_copy, mat2_copy, mat3_copy)
        ])

    def max_relative_error(lazy_var, dense_var):
        # Largest |lazy - dense| / |dense| entry, computed on raw tensors.
        difference = lazy_var.evaluate().data - dense_var.data
        return torch.max((difference / dense_var.data).abs())

    # Case 1: scale by a Variable holding a single constant.
    lazy_scaled = build_lazy() * const
    dense_scaled = build_dense() * const_copy
    assert max_relative_error(lazy_scaled, dense_scaled) < 0.01

    # Case 2: scale by a plain Python float.
    lazy_scaled = build_lazy() * 2.5
    dense_scaled = build_dense() * 2.5
    assert max_relative_error(lazy_scaled, dense_scaled) < 0.01
Exemplo n.º 2
0
def test_mul_adding_another_variable():
    """Multiplying a MulLazyVariable by a further RootLazyVariable must match dense.

    A two-factor ``MulLazyVariable`` is multiplied by a third
    ``RootLazyVariable``; the evaluated result is compared against ``prod``
    of the three explicit ``M.matmul(M^T)`` matrices, and the largest
    element-wise relative error has to stay below ``0.01``.
    """
    mat1 = make_random_mat(20, rank=4, batch_size=5)
    mat2 = make_random_mat(20, rank=4, batch_size=5)
    mat3 = make_random_mat(20, rank=4, batch_size=5)

    # Independent leaf Variables sharing the same data, used for the dense
    # reference computation.
    mat1_copy = Variable(mat1.data, requires_grad=True)
    mat2_copy = Variable(mat2.data, requires_grad=True)
    mat3_copy = Variable(mat3.data, requires_grad=True)

    # Lazy side: start from two factors, then multiply in the third.
    two_factor = MulLazyVariable(RootLazyVariable(mat1), RootLazyVariable(mat2))
    lazy_result = two_factor * RootLazyVariable(mat3)

    # Dense side: prod over the explicit M.matmul(M^T) matrices.
    dense_factors = [
        root.matmul(root.transpose(-1, -2))
        for root in (mat1_copy, mat2_copy, mat3_copy)
    ]
    dense_result = prod(dense_factors)

    difference = lazy_result.evaluate().data - dense_result.data
    relative_error = (difference / dense_result.data).abs()
    assert torch.max(relative_error) < 0.01