def test_arrayexpr_push_indices_up_and_down():
    """Check the index push-down / push-up helpers against fixed expectations.

    Twelve consecutive indices are pushed through two different sets of
    index pairs, once via ``ArrayContraction`` and once via ``ArrayDiagonal``
    (the latter called with an extra argument of ``10``), and every result is
    compared with a hard-coded tuple.
    """
    positions = list(range(12))

    # Scenario 1: the paired indices are interleaved, (0, 6) and (2, 8).
    pairs = [(0, 6), (2, 8)]

    contr_down = ArrayContraction._push_indices_down(pairs, positions)
    assert contr_down == (1, 3, 4, 5, 7, 9, 10, 11, 12, 13, 14, 15)
    contr_up = ArrayContraction._push_indices_up(pairs, positions)
    assert contr_up == (None, 0, None, 1, 2, 3, None, 4, None, 5, 6, 7)

    diag_down = ArrayDiagonal._push_indices_down(pairs, positions, 10)
    assert diag_down == (1, 3, 4, 5, 7, 9, (0, 6), (2, 8), None, None, None, None)
    diag_up = ArrayDiagonal._push_indices_up(pairs, positions, 10)
    assert diag_up == (6, 0, 7, 1, 2, 3, 6, 4, 7, 5, None, None)

    # Scenario 2: each pair is made of adjacent indices, (1, 2) and (7, 8).
    pairs = [(1, 2), (7, 8)]

    contr_down = ArrayContraction._push_indices_down(pairs, positions)
    assert contr_down == (0, 3, 4, 5, 6, 9, 10, 11, 12, 13, 14, 15)
    contr_up = ArrayContraction._push_indices_up(pairs, positions)
    assert contr_up == (0, None, None, 1, 2, 3, 4, None, None, 5, 6, 7)

    diag_down = ArrayDiagonal._push_indices_down(pairs, positions, 10)
    assert diag_down == (0, 3, 4, 5, 6, 9, (1, 2), (7, 8), None, None, None, None)
    diag_up = ArrayDiagonal._push_indices_up(pairs, positions, 10)
    assert diag_up == (0, 6, 6, 1, 2, 3, 4, 7, 7, 5, None, None)
def identify_hadamard_products(expr: Union[ArrayContraction, ArrayDiagonal]):
    """Collapse groups of arguments sharing the same index pair into Hadamard products.

    The arguments of ``expr`` are loaded into an ``_EditArrayContraction``
    editor.  For an ``ArrayDiagonal`` the diagonalized positions are written
    into the editor as negative indices (``-1 - n`` for the ``n``-th diagonal
    group) so that they can be told apart from the non-negative indices
    already stored in the editor.

    Arguments are then grouped by the set of indices they carry.  A group is
    replaced by a single ``hadamard_product(...)`` argument when it has
    exactly two distinct indices and at least two arguments, unless every
    index is non-negative and occurs exactly once per argument of the group
    (that case is treated as a trace and left alone).  Arguments whose indices
    appear in reversed order are wrapped in ``Transpose`` first.

    Parameters
    ==========

    expr : ArrayContraction or ArrayDiagonal
        Expression to rewrite.  For an ``ArrayDiagonal`` the inner expression
        must be an ``ArrayContraction`` or an ``ArrayTensorProduct``.

    Returns
    =======

    For an ``ArrayContraction`` input, the result of
    ``editor.to_array_contraction()``.  For an ``ArrayDiagonal`` input, a
    ``PermuteDims`` wrapping an ``ArrayDiagonal`` of that contraction, built
    from the diagonal positions recomputed after the rewrite.

    Raises
    ======

    NotImplementedError
        If ``expr`` is an ``ArrayDiagonal`` whose inner expression is neither
        an ``ArrayContraction`` nor an ``ArrayTensorProduct``.
    """
    mapping = _get_mapping_from_subranks(expr.subranks)

    editor: _EditArrayContraction
    if isinstance(expr, ArrayContraction):
        editor = _EditArrayContraction(expr)
    elif isinstance(expr, ArrayDiagonal):
        if isinstance(expr.expr, ArrayContraction):
            editor = _EditArrayContraction(expr.expr)
            diagonalized = ArrayContraction._push_indices_down(
                expr.expr.contraction_indices, expr.diagonal_indices)
        elif isinstance(expr.expr, ArrayTensorProduct):
            editor = _EditArrayContraction(None)
            editor.args_with_ind = [_ArgE(arg) for arg in expr.expr.args]
            diagonalized = expr.diagonal_indices
        else:
            raise NotImplementedError("not implemented")

        # Trick: add diagonalized indices as negative indices into the editor object:
        for i, e in enumerate(diagonalized):
            for j in e:
                arg_pos, rel_pos = mapping[j]
                editor.args_with_ind[arg_pos].indices[rel_pos] = -1 - i

    # Group the arguments that have no free (None) index by their index set,
    # and count how many times every index occurs over all arguments.
    map_contr_to_args: Dict[FrozenSet, List[_ArgE]] = defaultdict(list)
    map_ind_to_inds = defaultdict(int)
    for arg_with_ind in editor.args_with_ind:
        for ind in arg_with_ind.indices:
            map_ind_to_inds[ind] += 1
        if None in arg_with_ind.indices:
            continue
        map_contr_to_args[frozenset(arg_with_ind.indices)].append(arg_with_ind)

    def check_transpose(x):
        # Map negative (diagonal) markers back to their group number and
        # report whether the resulting index sequence is in ascending order.
        x = [i if i >= 0 else -1 - i for i in x]
        return x == sorted(x)

    k: FrozenSet[int]
    v: List[_ArgE]
    for k, v in map_contr_to_args.items():
        if len(k) != 2:
            # Hadamard product only defined for matrices:
            continue
        if len(v) == 1:
            # Hadamard product with a single argument makes no sense:
            continue
        # NOTE(review): an earlier revision looped over ``k`` here and issued
        # ``continue`` when ``map_ind_to_inds[ind] <= 2`` ("no other
        # contraction").  That ``continue`` only advanced the inner loop, so
        # the check never skipped anything; it has been dropped to keep the
        # behavior unchanged.  Turning it into a real skip would change which
        # groups are collapsed -- confirm the intent before reintroducing it.

        # Check if expression is a trace:
        if all(map_ind_to_inds[j] == len(v) and j >= 0 for j in k):
            # This is a trace
            continue

        # This is a Hadamard product:
        hp = hadamard_product(*[
            i.element if check_transpose(i.indices) else Transpose(i.element)
            for i in v
        ])
        hp_indices = v[0].indices
        if not check_transpose(v[0].indices):
            hp_indices = list(reversed(hp_indices))
        # Put the product where the first argument of the group was, then
        # drop every argument that went into it.
        editor.insert_after(v[0], _ArgE(hp, hp_indices))
        for i in v:
            editor.args_with_ind.remove(i)

    # Offset of the current argument, counted in indices that end up as None
    # in the preceding arguments:
    free_offset = 0
    # Create a collector for the new diagonal indices:
    diag_indices = defaultdict(list)

    count_index_freq = Counter()
    for arg_with_ind in editor.args_with_ind:
        count_index_freq.update(Counter(arg_with_ind.indices))

    free_index_count = count_index_freq[None]

    # Construct the inverse permutation:
    inv_perm1 = []
    inv_perm2 = []
    # Keep track of which diagonal indices have already been processed:
    done = set()

    # Running count of the free (None) indices seen so far:
    free_seen = 0

    for arg_with_ind in editor.args_with_ind:
        # If some diagonalization axes have been removed, they should be
        # permuted in order to keep the permutation.
        local_pos = 0  # position among the free/diagonal indices of this argument
        for i in arg_with_ind.indices:
            if i is None:
                inv_perm1.append(free_seen)
                local_pos += 1
                free_seen += 1
                continue
            if i >= 0:
                continue
            # Reconstruct the diagonal indices:
            diag_indices[-1 - i].append(free_offset + local_pos)
            if i not in done:
                # A marker occurring only once goes with the free indices,
                # a repeated one is appended after all of them.
                if count_index_freq[i] == 1:
                    inv_perm1.append(free_index_count - 1 - i)
                else:
                    inv_perm2.append(free_index_count - 1 - i)
                done.add(i)
            local_pos += 1
        # Remove negative indices to restore a proper editor object:
        arg_with_ind.indices = [
            i if i is not None and i >= 0 else None
            for i in arg_with_ind.indices
        ]
        # After the reset above the former negative markers are None as well,
        # so counting None covers both free and diagonal positions.
        free_offset += sum(1 for i in arg_with_ind.indices if i is None)

    if isinstance(expr, ArrayContraction):
        return editor.to_array_contraction()

    # Get the diagonal indices after the detection of HadamardProduct in the expression:
    diag_indices_filtered = [
        tuple(positions) for positions in diag_indices.values()
        if len(positions) > 1
    ]

    # The permutation is only needed when the diagonal is rebuilt.
    inverse_permutation = inv_perm1 + inv_perm2
    permutation = _af_invert(inverse_permutation)

    expr1 = editor.to_array_contraction()
    expr2 = ArrayDiagonal(expr1, *diag_indices_filtered)
    expr3 = PermuteDims(expr2, permutation)
    return expr3