def test_circulant_transpose():
    """Check ``circulant_transpose`` against an explicit matrix transpose.

    A random length-5 vector is expanded with ``circulant.circulant`` and
    transposed with ``.t()``.  The same vector is then passed through
    ``circulant.circulant_transpose`` and expanded; both matrices must be
    approximately equal.
    """
    vec = torch.randn(5)

    # Reference: expand first, then transpose the resulting matrix.
    expected = circulant.circulant(vec).t()

    # Under test: transform the vector first, then expand it.
    transposed_vec = circulant.circulant_transpose(vec)
    actual = circulant.circulant(transposed_vec)

    assert utils.approx_equal(expected, actual)
def test_frobenius_circulant_approximation():
    """Check that the Frobenius circulant approximation survives a round trip.

    The approximation of a random 5x5 matrix is expanded with
    ``circulant.circulant`` and approximated again; the second result must be
    approximately equal to the first.
    """
    dense = torch.randn(5, 5)

    first_pass = circulant.frobenius_circulant_approximation(dense)

    # Expand the first result and approximate it again.
    rebuilt = circulant.circulant(first_pass)
    second_pass = circulant.frobenius_circulant_approximation(rebuilt)

    assert utils.approx_equal(first_pass, second_pass)
def test_circulant_matmul():
    """Check ``circulant_matmul`` against an explicit matrix product.

    For a random length-5 vector and a random 5x5 matrix, the output of
    ``circulant.circulant_matmul`` must be approximately equal to expanding
    the vector with ``circulant.circulant`` and multiplying with ``.mm``.
    """
    vec = torch.randn(5)
    rhs = torch.randn(5, 5)

    # Under test: the helper that takes the vector directly.
    product_under_test = circulant.circulant_matmul(vec, rhs)

    # Reference: materialise the matrix and multiply explicitly.
    reference_product = circulant.circulant(vec).mm(rhs)

    assert utils.approx_equal(product_under_test, reference_product)
def test_circulant_inv_matmul(self=None):
    """Check ``circulant_inv_matmul`` against an explicit inverse-then-multiply.

    For a random length-5 vector and a random 5x5 matrix, the output of
    ``circulant.circulant_inv_matmul`` must be approximately equal to expanding
    the vector with ``circulant.circulant``, inverting that matrix, and
    multiplying with ``.mm``.

    Args:
        self: Unused.  Kept as an optional parameter only so that a call that
            passes an instance (e.g. if this is bound as a method) still
            works.  Because it has a default, pytest does not try to resolve
            it as a fixture when this runs as a free test function.
    """
    vec = torch.randn(5)
    rhs = torch.randn(5, 5)

    # Under test: the helper that takes the vector directly.
    product_under_test = circulant.circulant_inv_matmul(vec, rhs)

    # Reference: materialise the matrix, invert it, then multiply.
    reference_product = circulant.circulant(vec).inverse().mm(rhs)

    # Plain assert (not self.assertTrue): this function does not depend on a
    # unittest.TestCase instance, matching the other tests in this file.
    assert utils.approx_equal(product_under_test, reference_product)