def test_right_rotate_trace():
    """Check ``circulant.left_rotate_trace`` for negative shifts.

    For every shift in -1, -2, -3, -4 the trace of the explicitly rotated
    matrix (``circulant.rotate``) is used as the reference value, and
    ``circulant.left_rotate_trace`` must agree with it to within an absolute
    tolerance of 1e-5. (A negative left shift presumably corresponds to a
    right rotation, per the test name -- confirm against ``circulant``.)
    """
    matrix = torch.randn(5, 5)
    for shift in range(-1, -5, -1):
        expected = circulant.rotate(matrix, shift).trace()
        computed = circulant.left_rotate_trace(matrix, shift)
        assert math.fabs(expected - computed) < 1e-5
def test_left_rotate_trace(self):
    """Check ``circulant.left_rotate_trace`` for positive shifts 1..4.

    The reference value is the trace of the matrix produced by
    ``circulant.rotate``. The fast trace must match it to within an
    absolute tolerance of 1e-5.
    """
    tolerance = 1e-5
    matrix = torch.randn(5, 5)
    for shift in range(1, 5):
        expected = circulant.rotate(matrix, shift).trace()
        computed = circulant.left_rotate_trace(matrix, shift)
        gap = math.fabs(expected - computed)
        self.assertLess(gap, tolerance)
def test_rotate_vector_forward():
    """Check ``circulant.rotate`` on a length-5 vector against a shift matrix.

    ``shift`` is the 5x5 cyclic permutation matrix with a one at (0, 4) and
    ones on the sub-diagonal. Multiplying by it moves every entry down one
    position and wraps the last entry to the front. For k = 1..4,
    ``rotate(vec, k)`` must approximately equal ``shift**k @ vec``.
    """
    vec = torch.randn(5)

    shift = torch.zeros(5, 5)
    shift[0, 4] = 1
    shift[1:, :-1] = torch.eye(4)

    # ``power`` holds shift**k at the top of each iteration.
    power = shift.clone()
    for k in range(1, 5):
        got = circulant.rotate(vec, k)
        expected = power.matmul(vec)
        assert utils.approx_equal(expected, got)
        power = power.matmul(shift)
def test_rotate_matrix_reverse(self):
    """Check ``circulant.rotate`` with negative shifts on a 5x5 matrix.

    ``Q0`` is the 5x5 cyclic permutation matrix with a one at (0, 4) and ones
    on the sub-diagonal. For i = 1..4, ``rotate(a, -i)`` must approximately
    equal ``(Q0**i)^{-1} @ a``.
    """
    a = torch.randn(5, 5)
    Q0 = torch.zeros(5, 5)
    Q0[0, 4] = 1
    Q0[1:, :-1] = torch.eye(4)
    # Q holds Q0**i at the top of each iteration.
    Q = Q0.clone()
    for i in range(1, 5):
        a_rotated_result = circulant.rotate(a, -i)
        # Q is a power of a permutation matrix, hence orthogonal, so its
        # exact inverse is its transpose; no numerical inversion is needed.
        a_rotated_actual = Q.t().matmul(a)
        self.assertTrue(
            utils.approx_equal(a_rotated_actual, a_rotated_result),
            msg="rotate(a, %d) does not match the permutation reference" % -i,
        )
        Q = Q.matmul(Q0)