def test_tensorflow_with_constants(constants):
    """Check a contraction expression with constant operands on tensorflow.

    ``constants`` is the set of operand positions (out of 0, 1, 2) that are
    supplied as concrete arrays when the expression is built; the single
    remaining position is supplied when the expression is called.
    """
    eq = 'ij,jk,kl->li'
    shapes = (2, 3), (3, 4), (4, 5)
    non_const, = {0, 1, 2} - constants
    # constant positions carry real data, the variable position only a shape
    ops = [
        np.random.rand(*shp) if i in constants else shp
        for i, shp in enumerate(shapes)
    ]
    var = np.random.rand(*shapes[non_const])
    # reference result: plain contraction with every operand given explicitly
    res_exp = contract(eq,
                       *(ops[i] if i in constants else var for i in range(3)))

    expr = contract_expression(eq, *ops, constants=constants)

    # check tensorflow
    # Use the session itself as the context manager: inside the block it is
    # the default session, and it is closed on exit. ``Session.as_default()``
    # on its own would leave the session open after the block.
    with tf.Session(config=_TF_CONFIG):
        res_got = expr(var, backend='tensorflow')
    assert all(array is None or infer_backend(array) == 'tensorflow'
               for array in expr._evaluated_constants['tensorflow'])
    assert np.allclose(res_exp, res_got)

    # check can call with numpy still
    res_got2 = expr(var, backend='numpy')
    assert np.allclose(res_exp, res_got2)

    # check tensorflow call returns tensorflow still
    res_got3 = expr(backends.to_tensorflow(var))
    assert isinstance(res_got3, tf.Tensor)
def test_torch_with_constants(constants):
    """Check a contraction expression with constant operands on torch.

    ``constants`` is the set of operand positions (out of 0, 1, 2) that are
    supplied as concrete arrays when the expression is built; the single
    remaining position is supplied when the expression is called.
    """
    equation = 'ij,jk,kl->li'
    dims = [(2, 3), (3, 4), (4, 5)]
    (free_idx,) = {0, 1, 2} - constants

    # constant positions carry real data, the free position only its shape
    operands = []
    for idx, shape in enumerate(dims):
        operands.append(np.random.rand(*shape) if idx in constants else shape)
    free_arr = np.random.rand(*dims[free_idx])

    # reference result: every operand given explicitly to a plain contract
    explicit = list(operands)
    explicit[free_idx] = free_arr
    expected = contract(equation, *explicit)

    expr = contract_expression(equation, *operands, constants=constants)

    # explicit torch backend: cached constants must be torch arrays (or None)
    via_torch = expr(free_arr, backend='torch')
    for cached in expr._evaluated_constants['torch']:
        assert cached is None or infer_backend(cached) == 'torch'
    assert np.allclose(expected, via_torch)

    # the same expression must still be usable with numpy
    via_numpy = expr(free_arr, backend='numpy')
    assert np.allclose(expected, via_numpy)

    # a torch input must give a torch output with the right values
    torch_out = expr(backends.to_torch(free_arr))
    assert isinstance(torch_out, torch.Tensor)
    if torch_out.device.type == 'cpu':
        torch_out = torch_out.numpy()
    else:
        torch_out = torch_out.cpu().numpy()
    assert np.allclose(expected, torch_out)
def test_cupy_with_constants(constants):  # pragma: no cover
    """Check a contraction expression with constant operands on cupy.

    ``constants`` is the set of operand positions (out of 0, 1, 2) that are
    supplied as concrete arrays when the expression is built; the single
    remaining position is supplied when the expression is called.
    """
    equation = 'ij,jk,kl->li'
    dims = [(2, 3), (3, 4), (4, 5)]
    (free_idx,) = {0, 1, 2} - constants

    # constant positions carry real data, the free position only its shape
    operands = []
    for idx, shape in enumerate(dims):
        operands.append(np.random.rand(*shape) if idx in constants else shape)
    free_arr = np.random.rand(*dims[free_idx])

    # reference result: every operand given explicitly to a plain contract
    explicit = list(operands)
    explicit[free_idx] = free_arr
    expected = contract(equation, *explicit)

    expr = contract_expression(equation, *operands, constants=constants)

    # explicit cupy backend: cached constants must be cupy arrays (or None)
    via_cupy = expr(free_arr, backend='cupy')
    for cached in expr._evaluated_constants['cupy']:
        assert cached is None or infer_backend(cached) == 'cupy'
    assert np.allclose(expected, via_cupy)

    # the same expression must still be usable with numpy
    via_numpy = expr(free_arr, backend='numpy')
    assert np.allclose(expected, via_numpy)

    # a cupy input must give a cupy output with the right values
    cupy_out = expr(cupy.asarray(free_arr))
    assert isinstance(cupy_out, cupy.ndarray)
    assert np.allclose(expected, cupy_out.get())
def test_theano_with_constants(constants):
    """Check a contraction expression with constant operands on theano.

    ``constants`` is the set of operand positions (out of 0, 1, 2) that are
    supplied as concrete arrays when the expression is built; the single
    remaining position is supplied when the expression is called.
    """
    equation = 'ij,jk,kl->li'
    dims = [(2, 3), (3, 4), (4, 5)]
    (free_idx,) = {0, 1, 2} - constants

    # constant positions carry real data, the free position only its shape
    operands = []
    for idx, shape in enumerate(dims):
        operands.append(np.random.rand(*shape) if idx in constants else shape)
    free_arr = np.random.rand(*dims[free_idx])

    # reference result: every operand given explicitly to a plain contract
    explicit = list(operands)
    explicit[free_idx] = free_arr
    expected = contract(equation, *explicit)

    expr = contract_expression(equation, *operands, constants=constants)

    # explicit theano backend: cached constants must be theano arrays (or None)
    via_theano = expr(free_arr, backend='theano')
    for cached in expr._evaluated_constants['theano']:
        assert cached is None or infer_backend(cached) == 'theano'
    assert np.allclose(expected, via_theano)

    # the same expression must still be usable with numpy
    via_numpy = expr(free_arr, backend='numpy')
    assert np.allclose(expected, via_numpy)

    # a theano input must give a theano output
    symbolic_out = expr(backends.to_theano(free_arr))
    assert isinstance(symbolic_out, theano.tensor.TensorVariable)
def test_jax_with_constants(constants):  # pragma: no cover
    """Check a contraction expression with constant operands on jax.

    ``constants`` is the set of operand positions (out of 0, 1, 2) that are
    supplied as concrete arrays when the expression is built; the single
    remaining position is supplied when the expression is called.
    """
    equation = 'ij,jk,kl->li'
    dims = [(2, 3), (3, 4), (4, 5)]
    (free_idx,) = {0, 1, 2} - constants

    # constant positions carry real data, the free position only its shape
    operands = []
    for idx, shape in enumerate(dims):
        operands.append(np.random.rand(*shape) if idx in constants else shape)
    free_arr = np.random.rand(*dims[free_idx])

    # reference result: every operand given explicitly to a plain contract
    explicit = list(operands)
    explicit[free_idx] = free_arr
    expected = contract(equation, *explicit)

    expr = contract_expression(equation, *operands, constants=constants)

    # explicit jax backend: cached constants must be jax arrays (or None)
    via_jax = expr(free_arr, backend='jax')
    for cached in expr._evaluated_constants['jax']:
        assert cached is None or infer_backend(cached) == 'jax'

    assert np.allclose(expected, via_jax)
def test_auto_backend_custom_array_no_tensordot():
    """Check that 'auto' backend parsing yields 'numpy' for a ``Shaped`` input."""
    placeholder = Shaped((1, 2, 3))

    # Shaped is an array-like object defined by opt_einsum - which has no TDOT
    detected = infer_backend(placeholder)
    assert detected == 'opt_einsum'

    resolved = parse_backend([placeholder], 'auto')
    assert resolved == 'numpy'