def test_torch_with_constants(constants):
    """Check a contraction expression with constant operands on the torch backend.

    ``constants`` must leave exactly one of the operand indices {0, 1, 2}
    unaccounted for; that operand is supplied at call time, the others are
    baked into the expression. (``constants`` is presumably provided by a
    parametrize decorator or fixture not visible here.)
    """
    equation = 'ij,jk,kl->li'
    shapes = (2, 3), (3, 4), (4, 5)

    # Unpacking a one-element set: fails loudly unless exactly one index
    # of {0, 1, 2} is not a constant.
    (variable_index,) = {0, 1, 2} - constants

    # Constant slots get concrete arrays; the variable slot only its shape.
    operands = []
    for idx, shape in enumerate(shapes):
        operands.append(np.random.rand(*shape) if idx in constants else shape)
    variable = np.random.rand(*shapes[variable_index])

    # Reference result: plain contract with the variable array slotted in.
    full_operands = [operands[idx] if idx in constants else variable for idx in range(3)]
    expected = contract(equation, *full_operands)

    expr = contract_expression(equation, *operands, constants=constants)

    # Explicit torch backend; afterwards every cached constant for 'torch'
    # must be either None or a torch array.
    got_torch = expr(variable, backend='torch')
    cached = expr._evaluated_constants['torch']
    assert all(arr is None or infer_backend(arr) == 'torch' for arr in cached)
    assert np.allclose(expected, got_torch)

    # The same expression must still be usable with the numpy backend.
    got_numpy = expr(variable, backend='numpy')
    assert np.allclose(expected, got_numpy)

    # Passing a torch tensor without naming a backend should return a tensor.
    got_tensor = expr(backends.to_torch(variable))
    assert isinstance(got_tensor, torch.Tensor)
    # .cpu() is a no-op for tensors already on the CPU, so this covers both cases.
    assert np.allclose(expected, got_tensor.cpu().numpy())
def test_torch_with_constants():
    """Check a contraction expression whose outer operands are constants, on torch.

    Operands 0 and 2 of ``'ij,jk,kl->li'`` are fixed when the expression is
    built; operand 1 is supplied on each call. The expression is evaluated
    with the torch backend, then the numpy backend, then with a torch tensor
    input, and every result must match a plain ``contract`` of the same arrays.

    NOTE(review): another ``test_torch_with_constants`` (taking ``constants``)
    appears in this chunk; if both live in one module, the later definition
    shadows the earlier and only one is collected by pytest -- confirm and
    rename one of them if so.
    """
    eq = 'ij,jk,kl->li'
    shapes = (2, 3), (3, 4), (4, 5)
    constants = {0, 2}
    ops = [
        np.random.rand(*shp) if i in constants else shp
        for i, shp in enumerate(shapes)
    ]
    var = np.random.rand(*shapes[1])
    res_exp = contract(eq, ops[0], var, ops[2])

    expr = contract_expression(eq, *ops, constants=constants)

    # check torch
    res_got = expr(var, backend='torch')
    assert 'torch' in expr._evaluated_constants
    # Checking only the key would let a regression that cached numpy arrays
    # under 'torch' slip through; require the cached entries to be torch
    # tensors (None entries are skipped).
    assert all(
        array is None or isinstance(array, torch.Tensor)
        for array in expr._evaluated_constants['torch']
    )
    assert np.allclose(res_exp, res_got)

    # check can call with numpy still
    res_got2 = expr(var, backend='numpy')
    assert np.allclose(res_exp, res_got2)

    # check torch call returns torch still
    res_got3 = expr(backends.to_torch(var), backend='torch')
    assert isinstance(res_got3, torch.Tensor)
    # .cpu() returns the tensor itself when it is already on the CPU.
    assert np.allclose(res_exp, res_got3.cpu().numpy())
def test_torch(string):
    """Compare an optimized torch contraction against an unoptimized reference.

    ``string`` is an einsum equation (presumably supplied by a parametrize
    decorator not visible here); ``helpers.build_views`` builds the operand
    arrays for it.
    """
    operands = helpers.build_views(string)
    reference = contract(string, *operands, optimize=False, use_blas=False)

    expr = contract_expression(string, *(op.shape for op in operands), optimize=True)

    # Conversion mode: non-torch inputs evaluated with an explicit torch backend.
    converted = expr(*operands, backend='torch')
    assert np.allclose(reference, converted)

    # Non-conversion mode: torch inputs should produce a torch tensor.
    tensors = list(map(backends.to_torch, operands))
    native = expr(*tensors)
    assert isinstance(native, torch.Tensor)
    assert np.allclose(reference, native.cpu().numpy())