def test_arrayexpr_permutedims_sink():
    """Check the result of ``nest_permutation`` on un-nested permutations.

    Every case builds a permutation with ``nest_permutation=False``, passes
    it to ``nest_permutation()`` and compares the outcome with an explicitly
    written expected expression.  ``M``, ``N`` and ``P`` are operands defined
    outside this function.
    """

    def sink(inner, permutation):
        # Construct the permutation without nesting, then nest it explicitly.
        unsunk = _permute_dims(inner, permutation, nest_permutation=False)
        return nest_permutation(unsunk)

    # Only the last two axes are swapped: expected to land on N alone.
    result = sink(_array_tensor_product(M, N), [0, 1, 3, 2])
    assert result == _array_tensor_product(M, _permute_dims(N, [1, 0]))

    # Axes swapped pairwise: each factor is expected to get its own [1, 0].
    result = sink(_array_tensor_product(M, N), [1, 0, 3, 2])
    assert result == _array_tensor_product(
        _permute_dims(M, [1, 0]),
        _permute_dims(N, [1, 0]),
    )

    # Full reversal: the factors are expected in swapped order, each with [1, 0].
    result = sink(_array_tensor_product(M, N), [3, 2, 1, 0])
    assert result == _array_tensor_product(
        _permute_dims(N, [1, 0]),
        _permute_dims(M, [1, 0]),
    )

    # Permutation wrapped around a contraction: expected to move inside the
    # contraction, acting on the tensor product as the cycle [[0, 3]].
    result = sink(_array_contraction(_array_tensor_product(M, N), (1, 2)), [1, 0])
    assert result == _array_contraction(
        _permute_dims(_array_tensor_product(M, N), [[0, 3]]),
        (1, 2),
    )

    # Same input and expectation as the pairwise-swap case above.
    result = sink(_array_tensor_product(M, N), [1, 0, 3, 2])
    assert result == _array_tensor_product(
        _permute_dims(M, [1, 0]),
        _permute_dims(N, [1, 0]),
    )

    # Three factors with two contractions: expected cycle is [[0, 5]].
    result = sink(
        _array_contraction(_array_tensor_product(M, N, P), (1, 2), (3, 4)),
        [1, 0],
    )
    assert result == _array_contraction(
        _permute_dims(_array_tensor_product(M, N, P), [[0, 5]]),
        (1, 2),
        (3, 4),
    )
# Example #2
# 0
def test_arrayexpr_permutedims_sink():
    """Check the result of ``nest_permutation`` on un-nested ``PermuteDims``.

    Every case creates a ``PermuteDims`` with ``nest_permutation=False``,
    feeds it to ``nest_permutation()`` and compares the outcome with an
    explicitly written expected expression.  ``M``, ``N`` and ``P`` are
    operands defined outside this function.
    """

    def nested(operand, perm):
        # Create the un-nested node first, then request the nesting.
        flat = PermuteDims(operand, perm, nest_permutation=False)
        return nest_permutation(flat)

    # Swap of the last two axes only: expected to end up on N.
    actual = nested(ArrayTensorProduct(M, N), [0, 1, 3, 2])
    assert actual == ArrayTensorProduct(M, PermuteDims(N, [1, 0]))

    # Pairwise swap: expected to split into one [1, 0] per factor.
    actual = nested(ArrayTensorProduct(M, N), [1, 0, 3, 2])
    assert actual == ArrayTensorProduct(
        PermuteDims(M, [1, 0]),
        PermuteDims(N, [1, 0]),
    )

    # Full reversal: factors expected in swapped order, each with [1, 0].
    actual = nested(ArrayTensorProduct(M, N), [3, 2, 1, 0])
    assert actual == ArrayTensorProduct(
        PermuteDims(N, [1, 0]),
        PermuteDims(M, [1, 0]),
    )

    # Permutation of a contraction: expected inside the contraction, applied
    # to the tensor product as the cycle [[0, 3]].
    actual = nested(ArrayContraction(ArrayTensorProduct(M, N), (1, 2)), [1, 0])
    assert actual == ArrayContraction(
        PermuteDims(ArrayTensorProduct(M, N), [[0, 3]]),
        (1, 2),
    )

    # Same input and expectation as the pairwise-swap case above.
    actual = nested(ArrayTensorProduct(M, N), [1, 0, 3, 2])
    assert actual == ArrayTensorProduct(
        PermuteDims(M, [1, 0]),
        PermuteDims(N, [1, 0]),
    )

    # Three factors, two contractions: expected cycle is [[0, 5]].
    actual = nested(
        ArrayContraction(ArrayTensorProduct(M, N, P), (1, 2), (3, 4)),
        [1, 0],
    )
    assert actual == ArrayContraction(
        PermuteDims(ArrayTensorProduct(M, N, P), [[0, 5]]),
        (1, 2),
        (3, 4),
    )