def test_tall():
    """Check that python and torch ``tall`` agree on a boolean matrix.

    The boolean input is ``x <= y`` for two seeded random (100, 100)
    matrices. ``tall`` is compared with no axis, then for axis 0 and 1,
    each with and without ``keepdims=True``.
    """
    shape = (100, 100)
    py_rand.set_seed()
    py_x = py_rand.randn(shape)
    py_y = py_rand.randn(shape)
    torch_x = torch_matrix.float_tensor(py_x)
    torch_y = torch_matrix.float_tensor(py_y)

    py_res = py_matrix.lesser_equal(py_x, py_y)
    torch_res = torch_matrix.lesser_equal(torch_x, torch_y)

    # overall
    py_all = py_matrix.tall(py_res)
    torch_all = torch_matrix.tall(torch_res)
    assert py_all == torch_all, \
        "python tall != torch tall: overall"

    # Per-axis reductions. keepdims is only passed when True so the calls
    # match the original test exactly in the non-keepdims cases.
    for axis, keepdims in ((0, False), (1, False), (0, True), (1, True)):
        kwargs = {"axis": axis}
        if keepdims:
            kwargs["keepdims"] = True
        py_all = py_matrix.tall(py_res, **kwargs)
        torch_all = torch_matrix.tall(torch_res, **kwargs)
        py_torch_all = torch_matrix.to_numpy_array(torch_all)
        # Message suffix is "keepdims" for both axes (the original used
        # "keepdim" for axis 1, inconsistent with axis 0).
        suffix = ", keepdims" if keepdims else ""
        assert py_matrix.allclose(py_all, py_torch_all), \
            "python tall != torch tall: (axis-{}{})".format(axis, suffix)
def test_argmin():
    """Python and torch argmin must return the same indices along each axis."""
    py_rand.set_seed()
    matrix = py_rand.randn((100, 100))
    tensor = torch_matrix.float_tensor(matrix)

    for ax in (0, 1):
        expected = py_matrix.argmin(matrix, axis=ax)
        actual = torch_matrix.to_numpy_array(
            torch_matrix.argmin(tensor, axis=ax))
        assert py_matrix.allclose(expected, actual), \
            "python argmin != torch argmin: (axis-{})".format(ax)
def test_greater():
    """Elementwise ``greater`` must agree between the python and torch backends."""
    dims = (100, 100)
    py_rand.set_seed()
    lhs = py_rand.randn(dims)
    rhs = py_rand.randn(dims)
    lhs_t, rhs_t = torch_matrix.float_tensor(lhs), torch_matrix.float_tensor(rhs)

    expected = py_matrix.greater(lhs, rhs)
    actual = torch_matrix.to_numpy_array(torch_matrix.greater(lhs_t, rhs_t))
    assert py_matrix.allclose(expected, actual), \
        "python greater != torch greater"
def test_lesser_equal():
    """Elementwise ``lesser_equal`` must agree between the python and torch backends."""
    dims = (100, 100)
    py_rand.set_seed()
    lhs = py_rand.randn(dims)
    rhs = py_rand.randn(dims)
    lhs_t, rhs_t = torch_matrix.float_tensor(lhs), torch_matrix.float_tensor(rhs)

    expected = py_matrix.lesser_equal(lhs, rhs)
    actual = torch_matrix.to_numpy_array(torch_matrix.lesser_equal(lhs_t, rhs_t))
    assert py_matrix.allclose(expected, actual), \
        "python lesser_equal != torch lesser_equal"
def test_not_equal():
    """Elementwise ``not_equal`` must agree between the python and torch backends."""
    dims = (100, 100)
    py_rand.set_seed()
    lhs = py_rand.randn(dims)
    rhs = py_rand.randn(dims)
    lhs_t, rhs_t = torch_matrix.float_tensor(lhs), torch_matrix.float_tensor(rhs)

    expected = py_matrix.not_equal(lhs, rhs)
    actual = torch_matrix.to_numpy_array(torch_matrix.not_equal(lhs_t, rhs_t))
    assert py_matrix.allclose(expected, actual), \
        "python not equal != torch not equal"
def test_diag():
    """Round-trip vector -> diagonal matrix -> diagonal for each backend.

    Each backend is checked independently: ``diag(diagonal_matrix(v))``
    must reproduce ``v``.
    """
    shape = (100,)

    py_rand.set_seed()
    py_vec = py_rand.randn(shape)
    py_mat = py_matrix.diagonal_matrix(py_vec)
    py_diag = py_matrix.diag(py_mat)
    assert py_matrix.allclose(py_vec, py_diag), \
        "python vec -> matrix -> vec failure: diag"

    # Seed the torch generator too (as test_conversion/test_transpose do)
    # so a failure here is reproducible.
    torch_rand.set_seed()
    torch_vec = torch_rand.randn(shape)
    torch_mat = torch_matrix.diagonal_matrix(torch_vec)
    torch_diag = torch_matrix.diag(torch_mat)
    assert torch_matrix.allclose(torch_vec, torch_diag), \
        "torch vec -> matrix -> vec failure: diag"
def test_conversion():
    """Converting between numpy arrays and torch tensors must round-trip."""
    dims = (100, 100)

    # numpy -> torch -> numpy
    py_rand.set_seed()
    original_np = py_rand.rand(dims)
    round_trip_np = torch_matrix.to_numpy_array(
        torch_matrix.float_tensor(original_np))
    assert py_matrix.allclose(original_np, round_trip_np), \
        "python -> torch -> python failure"

    # torch -> numpy -> torch
    torch_rand.set_seed()
    original_t = torch_rand.rand(dims)
    round_trip_t = torch_matrix.float_tensor(
        torch_matrix.to_numpy_array(original_t))
    assert torch_matrix.allclose(original_t, round_trip_t), \
        "torch -> python -> torch failure"
def assert_close(pymat, torchmat, name, rtol=1e-05, atol=1e-06):
    """Assert that a python matrix and a torch matrix hold the same values.

    Each operand is converted to the other backend and compared there with
    that backend's ``allclose``. This ensures both comparison routines agree.

    Args:
        pymat: matrix from the python backend.
        torchmat: matrix from the torch backend.
        name: label prefixed to the assertion message.
        rtol: relative tolerance forwarded to both ``allclose`` calls.
        atol: absolute tolerance forwarded to both ``allclose`` calls.

    Raises:
        AssertionError: naming which backend's ``allclose`` rejected the pair.
    """
    as_numpy = torch_matrix.to_numpy_array(torchmat)
    as_torch = torch_matrix.float_tensor(pymat)
    py_ok = py_matrix.allclose(pymat, as_numpy, rtol=rtol, atol=atol)
    torch_ok = torch_matrix.allclose(torchmat, as_torch, rtol=rtol, atol=atol)

    if py_ok and torch_ok:
        return

    # Name the side(s) whose comparison failed.
    if not (py_ok or torch_ok):
        culprit = "both python and torch"
    elif py_ok:
        culprit = "torch"
    else:
        culprit = "python"
    assert False, "{}: failure at {} allclose".format(name, culprit)
def test_fill_diagonal():
    """``fill_diagonal`` on an identity matrix must equal scaling it.

    Applies to each backend. The two backends' results must also agree
    with each other.
    """
    n = 10
    py_mat = py_matrix.identity(n)
    torch_mat = torch_matrix.identity(n)
    fill_value = 2.0

    # Compute the scaled copy before fill_diagonal mutates py_mat in place.
    py_mult = fill_value * py_mat
    py_matrix.fill_diagonal(py_mat, fill_value)
    assert py_matrix.allclose(py_mat, py_mult), \
        "python fill != python multiply for diagonal matrix"

    torch_mult = fill_value * torch_mat
    torch_matrix.fill_diagonal(torch_mat, fill_value)
    # This compares two torch results, so the message names torch on both sides.
    assert torch_matrix.allclose(torch_mat, torch_mult), \
        "torch fill != torch multiply for diagonal matrix"

    assert_close(py_mat, torch_mat, "fill_diagonal")
def test_transpose():
    """Transposing in either backend must give the same result after conversion."""
    dims = (100, 100)

    # Start from numpy: compare numpy transpose with torch transpose.
    py_rand.set_seed()
    np_src = py_rand.rand(dims)
    t_src = torch_matrix.float_tensor(np_src)
    np_expected = py_matrix.transpose(np_src)
    np_actual = torch_matrix.to_numpy_array(torch_matrix.transpose(t_src))
    assert py_matrix.allclose(np_expected, np_actual), \
        "python -> torch -> python failure: transpose"

    # Start from torch: compare torch transpose with numpy transpose.
    torch_rand.set_seed()
    t_other = torch_rand.rand(dims)
    np_other = torch_matrix.to_numpy_array(t_other)
    t_expected = torch_matrix.transpose(t_other)
    t_actual = torch_matrix.float_tensor(py_matrix.transpose(np_other))
    assert torch_matrix.allclose(t_expected, t_actual), \
        "torch -> python -> torch failure: transpose"