コード例 #1
0
def test_torch_with_constants(constants):
    """Exercise a contract expression that has constant operands on torch.

    ``constants`` is the set of operand positions (drawn from 0, 1, 2) that
    are baked into the expression. The unpacking below requires that exactly
    one position is left over; that operand is supplied at call time.
    """
    eq = 'ij,jk,kl->li'
    shapes = (2, 3), (3, 4), (4, 5)
    (free_idx,) = {0, 1, 2} - constants

    # Constant positions receive concrete arrays; the free one only its shape.
    operands = []
    for idx, shape in enumerate(shapes):
        operands.append(np.random.rand(*shape) if idx in constants else shape)
    free_arr = np.random.rand(*shapes[free_idx])

    # Reference result: plain contraction with every operand filled in.
    full_args = [operands[idx] if idx in constants else free_arr
                 for idx in range(3)]
    expected = contract(eq, *full_args)

    expr = contract_expression(eq, *operands, constants=constants)

    # torch backend: every cached evaluated constant must be a torch object
    # (None entries are allowed, mirroring the original check).
    via_torch = expr(free_arr, backend='torch')
    assert all(cached is None or infer_backend(cached) == 'torch'
               for cached in expr._evaluated_constants['torch'])
    assert np.allclose(expected, via_torch)

    # The same expression must still be callable with the numpy backend.
    via_numpy = expr(free_arr, backend='numpy')
    assert np.allclose(expected, via_numpy)

    # Torch input with no explicit backend must yield a torch tensor.
    torch_out = expr(backends.to_torch(free_arr))
    assert isinstance(torch_out, torch.Tensor)
    # Tensor.cpu() returns the tensor itself when it already lives on the CPU,
    # so no device branch is needed before converting to numpy.
    assert np.allclose(expected, torch_out.cpu().numpy())
コード例 #2
0
ファイル: test_backends.py プロジェクト: liwt31/opt_einsum
def test_torch_with_constants():
    """Contract with operands 0 and 2 held constant, via torch and numpy.

    Operand 1 is the only argument passed when the expression is called.
    """
    eq = 'ij,jk,kl->li'
    shapes = (2, 3), (3, 4), (4, 5)
    const_positions = {0, 2}

    # Constant positions get concrete arrays; the variable one just its shape.
    operands = []
    for pos, shape in enumerate(shapes):
        if pos in const_positions:
            operands.append(np.random.rand(*shape))
        else:
            operands.append(shape)
    middle = np.random.rand(*shapes[1])

    first, _, last = operands
    expected = contract(eq, first, middle, last)

    expr = contract_expression(eq, *operands, constants=const_positions)

    # torch backend: constants must have been evaluated under the 'torch' key.
    via_torch = expr(middle, backend='torch')
    assert 'torch' in expr._evaluated_constants
    assert np.allclose(expected, via_torch)

    # The numpy backend must keep working on the same expression.
    via_numpy = expr(middle, backend='numpy')
    assert np.allclose(expected, via_numpy)

    # Torch input with backend='torch' must come back as a torch tensor.
    torch_out = expr(backends.to_torch(middle), backend='torch')
    assert isinstance(torch_out, torch.Tensor)
    # Tensor.cpu() is a no-op for CPU tensors, so one path covers both cases.
    assert np.allclose(expected, torch_out.cpu().numpy())
コード例 #3
0
def test_torch(string):
    """Compare an optimized contraction on the torch backend with plain einsum.

    ``string`` is the einsum equation; operand arrays come from
    ``helpers.build_views``.
    """
    views = helpers.build_views(string)
    reference = contract(string, *views, optimize=False, use_blas=False)

    expr = contract_expression(string, *(view.shape for view in views),
                               optimize=True)

    # numpy views evaluated through the torch backend.
    assert np.allclose(reference, expr(*views, backend='torch'))

    # Non-conversion mode: torch inputs, no backend given, torch output.
    torch_out = expr(*map(backends.to_torch, views))
    assert isinstance(torch_out, torch.Tensor)
    assert np.allclose(reference, torch_out.cpu().numpy())