def test_mul_adding_constant_mul():
    """Check that scaling a ``MulLazyVariable`` matches the explicit computation.

    Three matrices are drawn with ``make_random_mat(20, rank=4, batch_size=5)``.
    The lazy side wraps each one in a ``RootLazyVariable`` and combines them
    with ``MulLazyVariable``; the reference side forms ``m @ m^T`` explicitly
    for each copy and combines them with ``prod``.  Both sides are then
    multiplied by the same scale, first a one-element ``Variable`` and then
    the Python float ``2.5``, and the largest relative difference must stay
    below 0.01.
    """
    roots = [make_random_mat(20, rank=4, batch_size=5) for _ in range(3)]
    scale = Variable(torch.ones(1), requires_grad=True)

    # Independent copies of the same data feed the reference computation.
    root_copies = [Variable(root.data, requires_grad=True) for root in roots]
    scale_copy = Variable(scale.data, requires_grad=True)

    def build_lazy():
        # Lazy side: one RootLazyVariable per matrix, combined lazily.
        return MulLazyVariable(*[RootLazyVariable(root) for root in roots])

    def build_reference():
        # Reference side: explicit m @ m^T for every copy, combined via prod.
        return prod([
            copy.matmul(copy.transpose(-1, -2)) for copy in root_copies
        ])

    def max_relative_error(lazy, reference):
        difference = lazy.evaluate().data - reference.data
        return torch.max((difference / reference.data).abs())

    # First pass scales by a Variable, second pass by a plain float.
    for lazy_scale, reference_scale in ((scale, scale_copy), (2.5, 2.5)):
        res = build_lazy() * lazy_scale
        actual = build_reference() * reference_scale
        assert max_relative_error(res, actual) < 0.01
def test_mul_adding_another_variable():
    """Check multiplying a ``MulLazyVariable`` by one more lazy variable.

    Three matrices are drawn with ``make_random_mat(20, rank=4, batch_size=5)``.
    A ``MulLazyVariable`` is built from the first two (each wrapped in a
    ``RootLazyVariable``) and then multiplied by a ``RootLazyVariable`` of the
    third.  The evaluated result is compared with ``prod`` applied to the
    explicit ``m @ m^T`` of all three copies; the largest relative difference
    must stay below 0.01.
    """
    first, second, third = [
        make_random_mat(20, rank=4, batch_size=5) for _ in range(3)
    ]

    # Independent copies of the same data feed the reference computation.
    copies = [
        Variable(root.data, requires_grad=True)
        for root in (first, second, third)
    ]

    # Lazy side: start from two factors, then multiply in the third.
    two_factor = MulLazyVariable(RootLazyVariable(first), RootLazyVariable(second))
    res = two_factor * RootLazyVariable(third)

    # Reference side: explicit m @ m^T for every copy, combined via prod.
    actual = prod([copy.matmul(copy.transpose(-1, -2)) for copy in copies])

    difference = res.evaluate().data - actual.data
    relative_error = (difference / actual.data).abs()
    assert torch.max(relative_error) < 0.01