def test_edit_array_contraction():
    """Check that edits made through ``_EditArrayContraction`` round-trip.

    Each step mutates the editor's argument list (or the indices of one of
    its arguments) and compares the rebuilt expression with the expected
    contraction of a tensor product.
    """
    original = _array_contraction(_array_tensor_product(A, B, C, D), (1, 2, 5))
    editor = _EditArrayContraction(original)
    # An untouched editor must rebuild the expression it was created from.
    assert editor.to_array_contraction() == original

    # Swap the second and third arguments.
    second, third = editor.args_with_ind[1], editor.args_with_ind[2]
    editor.args_with_ind[1] = third
    editor.args_with_ind[2] = second
    swapped = _array_contraction(_array_tensor_product(A, C, B, D), (1, 3, 4))
    assert editor.to_array_contraction() == swapped

    # Insert a new argument whose two indices share a fresh contraction index.
    fresh_index = editor.get_new_contraction_index()
    traced_arg = _ArgE(X)
    traced_arg.indices = [fresh_index, fresh_index]
    editor.args_with_ind.insert(2, traced_arg)
    with_x = _array_contraction(
        _array_tensor_product(A, C, X, B, D), (1, 3, 6), (4, 5))
    assert editor.to_array_contraction() == with_x

    # Inspect the index bookkeeping exposed by the editor.
    assert editor.get_contraction_indices() == [[1, 3, 6], [4, 5]]
    relative_positions = editor.get_contraction_indices_to_ind_rel_pos()
    assert [[tuple(pair) for pair in group] for group in relative_positions] == [
        [(0, 1), (1, 1), (3, 0)],
        [(2, 0), (2, 1)],
    ]
    first_mapping = [list(pair) for pair in editor.get_mapping_for_index(0)]
    assert first_mapping == [[0, 1], [1, 1], [3, 0]]
    second_mapping = [list(pair) for pair in editor.get_mapping_for_index(1)]
    assert second_mapping == [[2, 0], [2, 1]]
    raises(ValueError, lambda: editor.get_mapping_for_index(2))

    # Drop the argument at position 1.
    editor.args_with_ind.pop(1)
    without_c = _array_contraction(
        _array_tensor_product(A, X, B, D), (1, 4), (2, 3))
    assert editor.to_array_contraction() == without_c

    # Rewire indices by hand: arg 0 joins arg 1's first index, then
    # arg 1's second index joins arg 2's first index.
    arg0, arg1, arg2 = editor.args_with_ind[:3]
    arg0.indices[1] = arg1.indices[0]
    arg1.indices[1] = arg2.indices[0]
    rewired = _array_contraction(
        _array_tensor_product(A, X, B, D), (1, 2), (3, 4))
    assert editor.to_array_contraction() == rewired

    # ``insert_after`` places a new, uncontracted argument after an existing one.
    editor.insert_after(editor.args_with_ind[1], _ArgE(C))
    with_c_again = _array_contraction(
        _array_tensor_product(A, X, C, B, D), (1, 2), (3, 6))
    assert editor.to_array_contraction() == with_c_again
示例#2
0
def _array_contraction_to_diagonal_multiple_identity(expr: ArrayContraction):
    """Drop identity matrices from contractions that join three or more arguments.

    For every contraction index that touches at least one ``Identity``
    argument and at least three arguments overall, the identities are marked
    for removal and the contraction index is replaced, in the remaining
    arguments, by a new negative index (``-1``, ``-2``, ...).

    Returns a pair ``(new_expr, removed)`` where ``new_expr`` is the result of
    ``editor.to_array_contraction()`` after the edit and ``removed`` lists the
    absolute free positions of the identities that were dropped (all but the
    first identity of each processed contraction).
    """
    editor = _EditArrayContraction(expr)
    editor.track_permutation_start()
    removed: List[int] = []
    diag_index_counter: int = 0
    for i in range(editor.number_of_contraction_indices):
        # Split the arguments carrying contraction index ``i`` into identity
        # matrices and everything else.
        identities = []
        args = []
        for j, arg in enumerate(editor.args_with_ind):
            if i not in arg.indices:
                continue
            if isinstance(arg.element, Identity):
                identities.append(arg)
            else:
                args.append(arg)
        if len(identities) == 0:
            continue
        # Fewer than three participants: nothing to rewrite for this index.
        if len(args) + len(identities) < 3:
            continue
        # NOTE(review): the negative index is reserved here, before the skip
        # check below, so a skipped contraction still consumes one value.
        new_diag_ind = -1 - diag_index_counter
        diag_index_counter += 1
        # Variable "flag" to control whether to skip this contraction set:
        flag: bool = True
        # Both branches of this loop end in ``break``, so only the first
        # identity is ever inspected and ``i1`` is always 0 afterwards.
        for i1, id1 in enumerate(identities):
            if None not in id1.indices:
                # The first identity has no free index: skip this contraction.
                flag = True
                break
            # First absolute free position of this identity; it is appended to
            # the last entry of the tracked permutation.
            free_pos = list(range(*editor.get_absolute_free_range(id1)))[0]
            editor._track_permutation[-1].append(free_pos)  # type: ignore
            id1.element = None
            flag = False
            break
        if flag:
            continue
        # Mark every other identity for removal and record its free positions.
        for arg in identities[:i1] + identities[i1 + 1:]:
            arg.element = None
            removed.extend(range(*editor.get_absolute_free_range(arg)))
        # Replace the contraction index by the new negative index in the
        # non-identity arguments.
        for arg in args:
            arg.indices = [new_diag_ind if j == i else j for j in arg.indices]
    # Blank out the permutation entries that sit at the positions of removed
    # arguments, then drop those entries.
    for j, e in enumerate(editor.args_with_ind):
        if e.element is None:
            editor._track_permutation[j] = None  # type: ignore
    editor._track_permutation = [
        i for i in editor._track_permutation if i is not None
    ]  # type: ignore
    # Renumber permutation array form in order to deal with deleted positions:
    remap = {
        e: i
        for i, e in enumerate(
            sorted({k
                    for j in editor._track_permutation for k in j}))
    }
    editor._track_permutation = [[remap[j] for j in i]
                                 for i in editor._track_permutation]
    # Finally discard the arguments that were marked with ``element = None``.
    editor.args_with_ind = [
        i for i in editor.args_with_ind if i.element is not None
    ]
    new_expr = editor.to_array_contraction()
    return new_expr, removed
示例#3
0
def _support_function_tp1_recognize(contraction_indices, args):
    """Merge contracted matrix arguments into products and traces.

    With no contraction indices the arguments are handed straight to
    ``_a2m_tensor_product``.  Otherwise the contraction is loaded into an
    ``_EditArrayContraction`` and scanned repeatedly: on each pass the first
    ``MatrixExpr`` argument that can be simplified is either turned into a
    ``Trace`` (both indices equal and used by no other argument) or merged
    with a later candidate found by
    ``_get_candidate_for_matmul_from_contraction``.  Scanning restarts after
    every change and stops once a full pass changes nothing.

    Returns the expression rebuilt by ``editor.to_array_contraction()``.
    """
    if len(contraction_indices) == 0:
        return _a2m_tensor_product(*args)

    contraction = ArrayContraction(ArrayTensorProduct(*args), *contraction_indices)
    editor = _EditArrayContraction(contraction)
    editor.track_permutation_start()

    progressed = True
    while progressed:
        progressed = False
        for position, current in enumerate(editor.args_with_ind):
            if not isinstance(current.element, MatrixExpr):
                continue

            first_index = current.indices[0]
            second_index = current.indices[1]
            first_frequency = editor.count_args_with_index(first_index)
            second_frequency = editor.count_args_with_index(second_index)

            # Both indices are the same contraction index and no other
            # argument uses it: the argument is a trace on its own.
            is_isolated_trace = (
                first_index is not None
                and first_frequency == 1
                and first_index == second_index
            )
            if is_isolated_trace:
                current.element = Trace(current.element)._normalize()
                current.indices = []
                progressed = True
                break

            # Only indices shared by exactly two arguments are scanned for a
            # matrix-multiplication partner.
            scan_indices = [
                index
                for index, frequency in (
                    (first_index, first_frequency),
                    (second_index, second_frequency),
                )
                if frequency == 2
            ]

            candidate, transpose, found_index = _get_candidate_for_matmul_from_contraction(
                scan_indices, editor.args_with_ind[position + 1:])
            if candidate is None:
                continue

            editor.track_permutation_merge(current, candidate)
            matched_on_first = found_index == first_index
            new_arge, other_index = _insert_candidate_into_editor(
                editor, current, candidate, matched_on_first, transpose)
            # The index that was not consumed by the merge stays in front.
            kept_index = second_index if matched_on_first else first_index
            new_arge.indices = [kept_index, other_index]

            distinct_indices = set(new_arge.indices)
            if len(distinct_indices) == 1 and None not in distinct_indices:
                # Both indices of the merged argument coincide: it is a trace.
                new_arge.element = Trace(new_arge.element)._normalize()
                new_arge.indices = []
            editor.args_with_ind[position] = new_arge
            progressed = True
            # TODO: is this break necessary?
            break

    editor.refresh_indices()
    return editor.to_array_contraction()
示例#4
0
def _remove_diagonalized_identity_matrices(expr: ArrayDiagonal):
    """Absorb diagonalized identity matrices into a ``DiagMatrix``.

    An ``Identity`` argument is dropped when it has one free index and one
    negative (diagonal) index, and another argument sharing that diagonal
    index is a ``MatrixExpr`` with a ``1`` in its shape and a free index of
    its own.  That other argument is wrapped in ``DiagMatrix`` and the
    absolute position of its free index is recorded as removed.

    Returns a pair ``(new_expr, removed)``: the expression rebuilt by
    ``editor.to_array_contraction()`` and the removed positions after
    ``ArrayDiagonal._push_indices_up``.
    """
    assert isinstance(expr, ArrayDiagonal)
    editor = _EditArrayContraction(expr)
    # Diagonal indices are the negative values -1, -2, ...; map each of them
    # to the set of arguments that carry it.
    mapping = {
        i: {j
            for j in editor.args_with_ind if i in j.indices}
        for i in range(-1, -1 - editor.number_of_diagonal_indices, -1)
    }
    removed = []
    for arg_with_ind in editor.args_with_ind:
        if not isinstance(arg_with_ind.element, Identity):
            continue
        indices = arg_with_ind.indices
        if None not in indices:
            continue
        # The explicit ``== True`` comparison is kept as written.
        if not any(idx is not None and (idx < 0) == True  # noqa: E712
                   for idx in indices):
            continue
        diag_ind = [j for j in indices if j is not None][0]
        other = [j for j in mapping[diag_ind] if j != arg_with_ind][0]
        if not isinstance(other.element, MatrixExpr):
            continue
        if 1 not in other.element.shape:
            continue
        if None not in other.indices:
            continue
        # Mark the identity for removal and fold it into the other argument.
        arg_with_ind.element = None
        none_index = other.indices.index(None)
        other.element = DiagMatrix(other.element)
        other_range = editor.get_absolute_range(other)
        removed.append(other_range[0] + none_index)
    editor.args_with_ind = [
        i for i in editor.args_with_ind if i.element is not None
    ]
    removed = ArrayDiagonal._push_indices_up(expr.diagonal_indices, removed,
                                             get_rank(expr.expr))
    return editor.to_array_contraction(), removed
示例#5
0
def identify_hadamard_products(expr: Union[ArrayContraction, ArrayDiagonal]):
    """Group arguments that share the same pair of indices into a Hadamard product.

    The expression is loaded into an ``_EditArrayContraction``.  For an
    ``ArrayDiagonal`` input the diagonalized positions are written into the
    editor as negative indices.  Arguments without free indices are grouped by
    their set of indices; each group with two distinct indices and more than
    one member, which is not recognized as a trace, is replaced by a single
    ``hadamard_product`` argument (members whose indices are not in sorted
    order are transposed first).

    Returns ``editor.to_array_contraction()`` for an ``ArrayContraction``
    input; otherwise returns a ``PermuteDims`` wrapping an ``ArrayDiagonal``
    of the rebuilt expression.

    Raises ``NotImplementedError`` when ``expr`` is an ``ArrayDiagonal`` whose
    inner expression is neither an ``ArrayContraction`` nor an
    ``ArrayTensorProduct``.
    """
    mapping = _get_mapping_from_subranks(expr.subranks)

    # NOTE(review): if ``expr`` is neither an ArrayContraction nor an
    # ArrayDiagonal, ``editor`` is never assigned and the code below fails.
    editor: _EditArrayContraction
    if isinstance(expr, ArrayContraction):
        editor = _EditArrayContraction(expr)
    elif isinstance(expr, ArrayDiagonal):
        if isinstance(expr.expr, ArrayContraction):
            editor = _EditArrayContraction(expr.expr)
            diagonalized = ArrayContraction._push_indices_down(
                expr.expr.contraction_indices, expr.diagonal_indices)
        elif isinstance(expr.expr, ArrayTensorProduct):
            # No contraction to load: build the argument list by hand.
            editor = _EditArrayContraction(None)
            editor.args_with_ind = [
                _ArgE(arg) for i, arg in enumerate(expr.expr.args)
            ]
            diagonalized = expr.diagonal_indices
        else:
            raise NotImplementedError("not implemented")

        # Trick: add diagonalized indices as negative indices into the editor object:
        for i, e in enumerate(diagonalized):
            for j in e:
                arg_pos, rel_pos = mapping[j]
                editor.args_with_ind[arg_pos].indices[rel_pos] = -1 - i

    # ``map_contr_to_args`` groups arguments without free indices by their set
    # of indices; ``map_ind_to_inds`` counts occurrences of every index value
    # (including ``None``).
    map_contr_to_args: Dict[FrozenSet, List[_ArgE]] = defaultdict(list)
    map_ind_to_inds = defaultdict(int)
    for arg_with_ind in editor.args_with_ind:
        for ind in arg_with_ind.indices:
            map_ind_to_inds[ind] += 1
        if None in arg_with_ind.indices:
            continue
        map_contr_to_args[frozenset(arg_with_ind.indices)].append(arg_with_ind)

    k: FrozenSet[int]
    v: List[_ArgE]
    for k, v in map_contr_to_args.items():
        if len(k) != 2:
            # Hadamard product only defined for matrices:
            continue
        if len(v) == 1:
            # Hadamard product with a single argument makes no sense:
            continue
        # NOTE(review): this ``continue`` only affects the inner ``for ind``
        # loop, so the check does not skip the current group.
        for ind in k:
            if map_ind_to_inds[ind] <= 2:
                # There is no other contraction, skip:
                continue

        # Check if expression is a trace:
        if all([map_ind_to_inds[j] == len(v) and j >= 0 for j in k]):
            # This is a trace
            continue

        # This is a Hadamard product:

        def check_transpose(x):
            # Negative (diagonal) indices ``-1 - n`` are mapped back to ``n``
            # before testing whether the indices are in sorted order.
            x = [i if i >= 0 else -1 - i for i in x]
            return x == sorted(x)

        hp = hadamard_product(*[
            i.element if check_transpose(i.indices) else Transpose(i.element)
            for i in v
        ])
        hp_indices = v[0].indices
        if not check_transpose(v[0].indices):
            hp_indices = list(reversed(hp_indices))
        # The product takes the place of the group: insert it after the first
        # member, then remove every member.
        editor.insert_after(v[0], _ArgE(hp, hp_indices))
        for i in v:
            editor.args_with_ind.remove(i)

    # Count the ranks of the arguments:
    counter = 0
    # Create a collector for the new diagonal indices:
    diag_indices = defaultdict(list)

    # Frequency of every index value over the edited argument list.
    count_index_freq = Counter()
    for arg_with_ind in editor.args_with_ind:
        count_index_freq.update(Counter(arg_with_ind.indices))

    free_index_count = count_index_freq[None]

    # Construct the inverse permutation:
    inv_perm1 = []
    inv_perm2 = []
    # Keep track of which diagonal indices have already been processed:
    done = set([])

    # Counter for the diagonal indices:
    counter4 = 0

    for arg_with_ind in editor.args_with_ind:
        # If some diagonalization axes have been removed, they should be
        # permuted in order to keep the permutation.
        # Add permutation here
        counter2 = 0  # counter for the indices
        for i in arg_with_ind.indices:
            if i is None:
                inv_perm1.append(counter4)
                counter2 += 1
                counter4 += 1
                continue
            if i >= 0:
                # Contraction indices take no part in the permutation.
                continue
            # Reconstruct the diagonal indices:
            diag_indices[-1 - i].append(counter + counter2)
            # A negative index seen only once goes to ``inv_perm1``; one seen
            # several times goes to ``inv_perm2``.  Each is recorded once.
            if count_index_freq[i] == 1 and i not in done:
                inv_perm1.append(free_index_count - 1 - i)
                done.add(i)
            elif i not in done:
                inv_perm2.append(free_index_count - 1 - i)
                done.add(i)
            counter2 += 1
        # Remove negative indices to restore a proper editor object:
        arg_with_ind.indices = [
            i if i is not None and i >= 0 else None
            for i in arg_with_ind.indices
        ]
        # After the reset above every remaining index is ``None`` or
        # non-negative, so this adds the number of ``None`` entries.
        counter += len([i for i in arg_with_ind.indices if i is None or i < 0])

    inverse_permutation = inv_perm1 + inv_perm2
    permutation = _af_invert(inverse_permutation)

    if isinstance(expr, ArrayContraction):
        return editor.to_array_contraction()
    else:
        # Get the diagonal indices after the detection of HadamardProduct in the expression:
        diag_indices_filtered = [
            tuple(v) for v in diag_indices.values() if len(v) > 1
        ]

        expr1 = editor.to_array_contraction()
        expr2 = ArrayDiagonal(expr1, *diag_indices_filtered)
        expr3 = PermuteDims(expr2, permutation)
        return expr3
示例#6
0
def remove_identity_matrices(expr: ArrayContraction):
    """Remove identity matrices contracted with a single non-identity argument.

    A contraction index is processed when it joins exactly one non-identity
    argument with one or more ``Identity`` arguments that each keep a free
    index.  The identities are dropped, the contraction index becomes a free
    index of the non-identity argument, and a permutation is applied so that
    the free indices of the result are ordered consistently with ``expr``.

    Returns a pair ``(ret_expr, removed)``: the permuted expression and the
    sorted list of free positions that were removed.
    """
    editor = _EditArrayContraction(expr)
    removed: List[int] = []

    permutation_map = {}

    # ``free_map`` gives, for each argument, the number of free indices held
    # by the arguments before it.
    free_indices = list(
        accumulate([0] + [
            sum([i is None for i in arg.indices])
            for arg in editor.args_with_ind
        ]))
    free_map = dict(zip(editor.args_with_ind, free_indices[:-1]))

    # (last_removed, contraction index) -> (non-identity argument, copy of its
    # indices before the contraction index was freed).  The argument is stored
    # with the pair so the second pass below looks up the right entry of
    # ``free_map`` rather than whichever argument the first loop saw last.
    update_pairs = {}

    for ind in range(editor.number_of_contraction_indices):
        args = editor.get_args_with_index(ind)
        identity_matrices = [
            i for i in args if isinstance(i.element, Identity)
        ]
        number_identity_matrices = len(identity_matrices)
        # Only handle one non-identity argument plus at least one identity:
        if number_identity_matrices != len(
                args) - 1 or number_identity_matrices == 0:
            continue
        # Get the non-identity element:
        non_identity = [
            i for i in args if not isinstance(i.element, Identity)
        ][0]
        # Check that all identity matrices have at least one free index
        # (otherwise they would be contractions to some other elements)
        if any(None not in i.indices for i in identity_matrices):
            continue
        # Mark the identity matrices for removal:
        for i in identity_matrices:
            i.element = None
            removed.extend(
                range(free_map[i],
                      free_map[i] + len([j for j in i.indices if j is None])))
        # The last recorded position is kept: it is reused for the index that
        # becomes free on the non-identity argument.
        last_removed = removed.pop(-1)
        update_pairs[last_removed, ind] = (non_identity, non_identity.indices[:])
        # Remove the indices from the non-identity matrix, as the contraction
        # no longer exists:
        non_identity.indices = [
            None if i == ind else i for i in non_identity.indices
        ]

    removed.sort()
    removed_positions = set(removed)
    rank = get_rank(expr)

    # shifts[p] is the number of removed positions that are <= p.
    shifts = list(
        accumulate(1 if i in removed_positions else 0 for i in range(rank)))
    for (last_removed, ind), (target_arg, target_indices) in update_pairs.items():
        pos = [
            free_map[target_arg] + i
            for i, e in enumerate(target_indices) if e == ind
        ]
        assert len(pos) == 1
        for j in pos:
            permutation_map[j] = last_removed

    editor.args_with_ind = [
        i for i in editor.args_with_ind if i.element is not None
    ]
    ret_expr = editor.to_array_contraction()

    # Build the permutation over the surviving positions.
    reserved_targets = set(permutation_map.values())
    permutation = []
    counter = 0
    counter2 = 0
    for j in range(rank):
        if j in removed_positions:
            continue
        if counter2 in permutation_map:
            target = permutation_map[counter2]
            permutation.append(target - shifts[target])
            counter2 += 1
        else:
            # Skip values that are reserved as targets of ``permutation_map``.
            while counter in reserved_targets:
                counter += 1
            permutation.append(counter)
            counter += 1
            counter2 += 1
    ret_expr2 = _permute_dims(ret_expr, _af_invert(permutation))
    return ret_expr2, removed
示例#7
0
def identify_removable_identity_matrices(expr):
    """Simplify identity and diagonal matrices inside a contraction.

    The argument list is rescanned after every change until a full pass
    changes nothing.  The rewrites are:

    * an ``Identity`` with one free index whose other index is used by no
      other argument becomes a ``OneArray`` with a single free index;
    * an ``Identity`` whose two indices are the same contraction index, shared
      with at least one other argument, is removed;
    * a diagonal matrix ``D_ii`` sharing its index with exactly two other
      arguments is rewritten as ``A_ai D_ij B_bj``.

    Returns the expression rebuilt by ``editor.to_array_contraction()``.
    """
    editor = _EditArrayContraction(expr)

    flag: bool = True
    while flag:
        flag = False
        for arg_with_ind in editor.args_with_ind:
            if isinstance(arg_with_ind.element, Identity):
                # Candidate for removal:
                if arg_with_ind.indices == [None, None]:
                    # Free identity matrix, will be cleared by _remove_trivial_dims:
                    continue
                elif None in arg_with_ind.indices:
                    ind = [j for j in arg_with_ind.indices if j is not None][0]
                    counted = editor.count_args_with_index(ind)
                    if counted == 1:
                        # Identity matrix contracted only on one index with itself,
                        # transform to a OneArray(k) element.  The argument is
                        # rewritten in place so the list keeps holding editor
                        # arguments (objects with ``element`` and ``indices``);
                        # the single remaining index is free.
                        k = arg_with_ind.element.shape[0]
                        arg_with_ind.element = OneArray(k)
                        arg_with_ind.indices = [None]
                        flag = True
                        break
                    # Case counted = 2 is a matrix multiplication by identity matrix, skip it.
                    # Case counted > 2 is a multiple contraction,
                    # this is a case where the contraction becomes a diagonalization if the
                    # identity matrix is dropped.
                elif arg_with_ind.indices[0] == arg_with_ind.indices[1]:
                    ind = arg_with_ind.indices[0]
                    counted = editor.count_args_with_index(ind)
                    if counted > 1:
                        editor.args_with_ind.remove(arg_with_ind)
                        flag = True
                        break
                    # Otherwise this is a trace, skip it as it will be
                    # recognized somewhere else.
            elif ask(Q.diagonal(arg_with_ind.element)):
                if arg_with_ind.indices == [None, None]:
                    continue
                elif None in arg_with_ind.indices:
                    pass
                elif arg_with_ind.indices[0] == arg_with_ind.indices[1]:
                    ind = arg_with_ind.indices[0]
                    counted = editor.count_args_with_index(ind)
                    if counted == 3:
                        # A_ai B_bi D_ii ==> A_ai D_ij B_bj
                        ind_new = editor.get_new_contraction_index()
                        # Only the arguments that actually carry ``ind`` take
                        # part in the rewrite; unrelated arguments must not be
                        # picked up by position.
                        other_args = [
                            j for j in editor.args_with_ind
                            if j != arg_with_ind and ind in j.indices
                        ]
                        other_args[1].indices = [
                            ind_new if j == ind else j
                            for j in other_args[1].indices
                        ]
                        arg_with_ind.indices = [ind, ind_new]
                        flag = True
                        break

    return editor.to_array_contraction()
示例#8
0
def identify_hadamard_products(expr: tUnion[ArrayContraction, ArrayDiagonal]):
    """Collapse arguments sharing the same indices into Hadamard products or traces.

    Arguments without free indices are grouped by their set of indices.  A
    group with more than one member is replaced by one argument holding the
    ``hadamard_product`` of its members (a member is transposed when its
    indices are not in sorted order).  When the group is recognized as a
    trace, the replacement is a ``Trace`` with no indices instead.

    Returns the expression rebuilt by ``editor.to_array_contraction()``.
    """

    def in_sorted_order(indices):
        # Negative indices ``-1 - n`` are mapped back to ``n`` before comparing.
        normalized = [idx if idx >= 0 else -1 - idx for idx in indices]
        return normalized == sorted(normalized)

    editor: _EditArrayContraction = _EditArrayContraction(expr)

    groups: tDict[FrozenSet, List[_ArgE]] = defaultdict(list)
    occurrences: tDict[Optional[int], int] = defaultdict(int)
    for arg in editor.args_with_ind:
        for idx in arg.indices:
            occurrences[idx] += 1
        if None not in arg.indices:
            groups[frozenset(arg.indices)].append(arg)

    for index_set, members in groups.items():
        make_trace: bool = False

        # A single non-negative index that appears in no other group means the
        # members are fully contracted among themselves: the result is a trace.
        sole_index_trace = False
        if len(index_set) == 1:
            (sole_index,) = index_set
            sole_index_trace = sole_index >= 0 and sum(
                sole_index in other_set for other_set in groups) == 1

        if sole_index_trace:
            make_trace = True
            first_element = S.One
        elif len(index_set) != 2:
            # Hadamard product only defined for matrices:
            continue

        if len(members) == 1:
            # Hadamard product with a single argument makes no sense:
            continue

        # NOTE: no filtering is applied here on how often each index occurs
        # outside the group.

        # Every index is non-negative and occurs exactly once per member:
        # the group as a whole is a trace.
        group_size = len(members)
        whole_group_is_trace = all(
            occurrences[idx] == group_size and idx >= 0 for idx in index_set)
        if whole_group_is_trace:
            make_trace = True
            leading = members[0]
            first_element = leading.element
            if not in_sorted_order(leading.indices):
                first_element = first_element.T
            hadamard_factors = members[1:]
        else:
            hadamard_factors = members

        # This is a Hadamard product:
        factor_elements = []
        for factor in hadamard_factors:
            if in_sorted_order(factor.indices):
                factor_elements.append(factor.element)
            else:
                factor_elements.append(Transpose(factor.element))
        hp = hadamard_product(*factor_elements)

        hp_indices = members[0].indices
        if not in_sorted_order(hadamard_factors[0].indices):
            hp_indices = list(reversed(hp_indices))
        if make_trace:
            hp = Trace(first_element * hp.T)._normalize()
            hp_indices = []

        # The new argument replaces the whole group.
        editor.insert_after(members[0], _ArgE(hp, hp_indices))
        for member in members:
            editor.args_with_ind.remove(member)

    return editor.to_array_contraction()