def test_tensorflow_with_constants(constants):
    """Run a constant-folded contraction through the tensorflow backend.

    Exactly one of the three operands is left variable (the one whose index
    is missing from ``constants``); the others are baked into the expression.
    Verifies that the cached tensorflow constants are tensorflow objects, that
    the result matches a plain ``contract`` call, that the numpy backend still
    works on the same expression, and that a tensorflow input yields a
    tensorflow output.
    """
    equation = 'ij,jk,kl->li'
    all_shapes = (2, 3), (3, 4), (4, 5)
    # Unpacking into a 1-tuple asserts exactly one operand is non-constant.
    (variable_pos,) = {0, 1, 2} - constants

    # Constant slots get real arrays; the variable slot only needs its shape.
    operands = []
    for pos, shape in enumerate(all_shapes):
        if pos in constants:
            operands.append(np.random.rand(*shape))
        else:
            operands.append(shape)
    variable = np.random.rand(*all_shapes[variable_pos])

    full_operands = [operands[pos] if pos in constants else variable
                     for pos in range(3)]
    expected = contract(equation, *full_operands)

    expr = contract_expression(equation, *operands, constants=constants)

    # tensorflow backend (graph mode needs an active default session)
    with tf.Session(config=_TF_CONFIG).as_default():
        tf_result = expr(variable, backend='tensorflow')
    cached = expr._evaluated_constants['tensorflow']
    assert all(arr is None or infer_backend(arr) == 'tensorflow'
               for arr in cached)
    assert np.allclose(expected, tf_result)

    # the numpy backend must keep working after tensorflow constants exist
    np_result = expr(variable, backend='numpy')
    assert np.allclose(expected, np_result)

    # a tensorflow input should produce a tensorflow output
    tf_native = expr(backends.to_tensorflow(variable))
    assert isinstance(tf_native, tf.Tensor)
def test_torch_with_constants(constants):
    """Run a constant-folded contraction through the torch backend.

    One operand (the index absent from ``constants``) stays variable; the
    rest are fixed in the expression. Checks the cached torch constants, the
    numeric result, continued numpy-backend support, and that a torch input
    gives back a torch tensor with the correct values.
    """
    equation = 'ij,jk,kl->li'
    all_shapes = (2, 3), (3, 4), (4, 5)
    # Unpacking into a 1-tuple asserts exactly one operand is non-constant.
    (variable_pos,) = {0, 1, 2} - constants

    # Constant slots get real arrays; the variable slot only needs its shape.
    operands = []
    for pos, shape in enumerate(all_shapes):
        if pos in constants:
            operands.append(np.random.rand(*shape))
        else:
            operands.append(shape)
    variable = np.random.rand(*all_shapes[variable_pos])

    full_operands = [operands[pos] if pos in constants else variable
                     for pos in range(3)]
    expected = contract(equation, *full_operands)

    expr = contract_expression(equation, *operands, constants=constants)

    # torch backend with a numpy input
    torch_result = expr(variable, backend='torch')
    cached = expr._evaluated_constants['torch']
    assert all(arr is None or infer_backend(arr) == 'torch' for arr in cached)
    assert np.allclose(expected, torch_result)

    # the numpy backend must keep working after torch constants exist
    np_result = expr(variable, backend='numpy')
    assert np.allclose(expected, np_result)

    # a torch input should produce a torch output
    torch_native = expr(backends.to_torch(variable))
    assert isinstance(torch_native, torch.Tensor)
    # .cpu() is a no-op for tensors already on the CPU, so this covers both
    # CPU and GPU results.
    assert np.allclose(expected, torch_native.cpu().numpy())
def test_cupy_with_constants(constants):  # pragma: no cover
    """Run a constant-folded contraction through the cupy backend.

    One operand (the index absent from ``constants``) stays variable; the
    rest are fixed in the expression. Checks the cached cupy constants, the
    numeric result, continued numpy-backend support, and that a cupy input
    gives back a cupy array with the correct values.
    """
    equation = 'ij,jk,kl->li'
    all_shapes = (2, 3), (3, 4), (4, 5)
    # Unpacking into a 1-tuple asserts exactly one operand is non-constant.
    (variable_pos,) = {0, 1, 2} - constants

    # Constant slots get real arrays; the variable slot only needs its shape.
    operands = []
    for pos, shape in enumerate(all_shapes):
        if pos in constants:
            operands.append(np.random.rand(*shape))
        else:
            operands.append(shape)
    variable = np.random.rand(*all_shapes[variable_pos])

    full_operands = [operands[pos] if pos in constants else variable
                     for pos in range(3)]
    expected = contract(equation, *full_operands)

    expr = contract_expression(equation, *operands, constants=constants)

    # cupy backend with a numpy input
    cupy_result = expr(variable, backend='cupy')
    # the constants should now also be cached as cupy arrays
    cached = expr._evaluated_constants['cupy']
    assert all(arr is None or infer_backend(arr) == 'cupy' for arr in cached)
    assert np.allclose(expected, cupy_result)

    # the numpy backend must keep working after cupy constants exist
    np_result = expr(variable, backend='numpy')
    assert np.allclose(expected, np_result)

    # a cupy input should produce a cupy output
    cupy_native = expr(cupy.asarray(variable))
    assert isinstance(cupy_native, cupy.ndarray)
    assert np.allclose(expected, cupy_native.get())
def test_theano_with_constants(constants):
    """Run a constant-folded contraction through the theano backend.

    One operand (the index absent from ``constants``) stays variable; the
    rest are fixed in the expression. Checks the cached theano constants, the
    numeric result, continued numpy-backend support, and that a theano input
    gives back a theano ``TensorVariable`` whose evaluated value is correct
    (previously only the type was checked, unlike the torch/cupy tests).
    """
    eq = 'ij,jk,kl->li'
    shapes = (2, 3), (3, 4), (4, 5)
    # Unpacking into a 1-tuple asserts exactly one operand is non-constant.
    non_const, = {0, 1, 2} - constants
    # Constant slots get real arrays; the variable slot only needs its shape.
    ops = [np.random.rand(*shp) if i in constants else shp
           for i, shp in enumerate(shapes)]
    var = np.random.rand(*shapes[non_const])
    res_exp = contract(eq, *(ops[i] if i in constants else var
                             for i in range(3)))

    expr = contract_expression(eq, *ops, constants=constants)

    # check theano
    res_got = expr(var, backend='theano')
    assert all(array is None or infer_backend(array) == 'theano'
               for array in expr._evaluated_constants['theano'])
    assert np.allclose(res_exp, res_got)

    # check can call with numpy still
    res_got2 = expr(var, backend='numpy')
    assert np.allclose(res_exp, res_got2)

    # check theano call returns theano still
    res_got3 = expr(backends.to_theano(var))
    assert isinstance(res_got3, theano.tensor.TensorVariable)
    # the symbolic result must also evaluate to the right numbers
    assert np.allclose(res_exp, res_got3.eval())
def test_jax_with_constants(constants):  # pragma: no cover
    """Run a constant-folded contraction through the jax backend.

    One operand (the index absent from ``constants``) stays variable; the
    rest are fixed in the expression. Checks the cached jax constants and the
    numeric result, and (for parity with the other backend tests) that the
    same expression still works with the numpy backend afterwards.
    """
    eq = 'ij,jk,kl->li'
    shapes = (2, 3), (3, 4), (4, 5)
    # Unpacking into a 1-tuple asserts exactly one operand is non-constant.
    non_const, = {0, 1, 2} - constants
    # Constant slots get real arrays; the variable slot only needs its shape.
    ops = [np.random.rand(*shp) if i in constants else shp
           for i, shp in enumerate(shapes)]
    var = np.random.rand(*shapes[non_const])
    res_exp = contract(eq, *(ops[i] if i in constants else var
                             for i in range(3)))

    expr = contract_expression(eq, *ops, constants=constants)

    # check jax
    res_got = expr(var, backend='jax')
    # check jax versions of constants exist
    assert all(array is None or infer_backend(array) == 'jax'
               for array in expr._evaluated_constants['jax'])
    assert np.allclose(res_exp, res_got)

    # check can call with numpy still
    res_got2 = expr(var, backend='numpy')
    assert np.allclose(res_exp, res_got2)
def test_auto_backend_custom_array_no_tensordot():
    """An array-like whose backend has no tensordot falls back to numpy."""
    # Shaped is opt_einsum's own shape-only array-like, which lacks tensordot
    shaped = Shaped((1, 2, 3))
    detected = infer_backend(shaped)
    assert detected == 'opt_einsum'
    chosen = parse_backend([shaped], 'auto')
    assert chosen == 'numpy'