def test_edit_array_contraction():
    # Round-trip: A*B*C*D contracted over axes 1, 2 and 5 must come back
    # unchanged from a freshly built editor.
    original = _array_contraction(_array_tensor_product(A, B, C, D), (1, 2, 5))
    editor = _EditArrayContraction(original)
    assert editor.to_array_contraction() == original

    # Swapping two arguments must renumber the contraction axes accordingly.
    swapped = editor.args_with_ind
    swapped[1], swapped[2] = swapped[2], swapped[1]
    expected = _array_contraction(_array_tensor_product(A, C, B, D), (1, 3, 4))
    assert editor.to_array_contraction() == expected

    # Insert X with both axes on a freshly allocated contraction index.
    fresh_index = editor.get_new_contraction_index()
    x_arg = _ArgE(X)
    x_arg.indices = [fresh_index, fresh_index]
    editor.args_with_ind.insert(2, x_arg)
    expected = _array_contraction(
        _array_tensor_product(A, C, X, B, D), (1, 3, 6), (4, 5))
    assert editor.to_array_contraction() == expected
    assert editor.get_contraction_indices() == [[1, 3, 6], [4, 5]]

    # (argument position, relative axis) pairs for every contraction index:
    rel_positions = [
        [tuple(pair) for pair in group]
        for group in editor.get_contraction_indices_to_ind_rel_pos()
    ]
    assert rel_positions == [[(0, 1), (1, 1), (3, 0)], [(2, 0), (2, 1)]]
    first_mapping = [list(pair) for pair in editor.get_mapping_for_index(0)]
    assert first_mapping == [[0, 1], [1, 1], [3, 0]]
    second_mapping = [list(pair) for pair in editor.get_mapping_for_index(1)]
    assert second_mapping == [[2, 0], [2, 1]]
    raises(ValueError, lambda: editor.get_mapping_for_index(2))

    # Dropping C leaves its two contraction partners (A and B) contracted
    # with each other.
    editor.args_with_ind.pop(1)
    expected = _array_contraction(
        _array_tensor_product(A, X, B, D), (1, 4), (2, 3))
    assert editor.to_array_contraction() == expected

    # Rewire the indices by hand so that A, X and B form a chain.
    editor.args_with_ind[0].indices[1] = editor.args_with_ind[1].indices[0]
    editor.args_with_ind[1].indices[1] = editor.args_with_ind[2].indices[0]
    expected = _array_contraction(
        _array_tensor_product(A, X, B, D), (1, 2), (3, 4))
    assert editor.to_array_contraction() == expected

    # insert_after places the new (fully free) argument right after X.
    editor.insert_after(editor.args_with_ind[1], _ArgE(C))
    expected = _array_contraction(
        _array_tensor_product(A, X, C, B, D), (1, 2), (3, 6))
    assert editor.to_array_contraction() == expected
def _array_contraction_to_diagonal_multiple_identity(expr: ArrayContraction):
    """
    Drop identity matrices from multi-way contractions, relabelling the index.

    A contraction index qualifies when it is carried by at least one
    ``Identity`` argument and by three or more arguments in total. The first
    such identity must have a free axis, otherwise the index is left alone.
    Its free-axis position is appended to the last group of the editor's
    tracked permutation and the identity is deleted. Every other identity on
    that index is deleted as well, and its absolute free-axis positions are
    collected in ``removed``. The remaining (non-identity) arguments get the
    contraction index replaced by a new negative index (-1, -2, ...).
    NOTE(review): negative indices appear to denote diagonal axes in the
    editor, judging from the other helpers in this source — confirm against
    ``_EditArrayContraction``.

    Returns ``(new_expr, removed)``, where ``new_expr`` is rebuilt with
    ``editor.to_array_contraction()``.
    """
    editor = _EditArrayContraction(expr)
    editor.track_permutation_start()
    removed: List[int] = []
    diag_index_counter: int = 0
    for i in range(editor.number_of_contraction_indices):
        # Split the arguments carrying contraction index ``i`` into identity
        # matrices and everything else:
        identities = []
        args = []
        for j, arg in enumerate(editor.args_with_ind):
            if i not in arg.indices:
                continue
            if isinstance(arg.element, Identity):
                identities.append(arg)
            else:
                args.append(arg)
        if len(identities) == 0:
            continue
        # Fewer than three participants is skipped (presumably an ordinary
        # matrix product handled elsewhere):
        if len(args) + len(identities) < 3:
            continue
        # NOTE(review): the counter advances before the ``flag`` check below,
        # so a skipped contraction still consumes a negative label.
        new_diag_ind = -1 - diag_index_counter
        diag_index_counter += 1
        # Variable "flag" to control whether to skip this contraction set:
        flag: bool = True
        # NOTE(review): both paths through this loop end in ``break``, so only
        # ``identities[0]`` is ever examined and ``i1`` is always 0.
        for i1, id1 in enumerate(identities):
            if None not in id1.indices:
                # Identity has no free axis: skip this contraction.
                flag = True
                break
            free_pos = list(range(*editor.get_absolute_free_range(id1)))[0]
            editor._track_permutation[-1].append(free_pos)  # type: ignore
            id1.element = None
            flag = False
            break
        if flag:
            continue
        # Delete the remaining identities and record their free axes:
        for arg in identities[:i1] + identities[i1 + 1:]:
            arg.element = None
            removed.extend(range(*editor.get_absolute_free_range(arg)))
        # Relabel index ``i`` with the new negative index on the kept args:
        for arg in args:
            arg.indices = [new_diag_ind if j == i else j for j in arg.indices]
    # Blank out permutation entries of deleted arguments. This assumes one
    # ``_track_permutation`` entry per argument, in the same order — TODO confirm.
    for j, e in enumerate(editor.args_with_ind):
        if e.element is None:
            editor._track_permutation[j] = None  # type: ignore
    editor._track_permutation = [i for i in editor._track_permutation if i is not None]  # type: ignore
    # Renumber permutation array form in order to deal with deleted positions:
    remap = {e: i for i, e in enumerate(sorted({k for j in editor._track_permutation for k in j}))}
    editor._track_permutation = [[remap[j] for j in i] for i in editor._track_permutation]
    editor.args_with_ind = [i for i in editor.args_with_ind if i.element is not None]
    new_expr = editor.to_array_contraction()
    return new_expr, removed
def _support_function_tp1_recognize(contraction_indices, args):
    """
    Recognize matrix products and traces inside a contracted tensor product.

    ``args`` are the factors of a tensor product and ``contraction_indices``
    the contractions applied to it. With no contractions, the factors are just
    combined with ``_a2m_tensor_product``. Otherwise an
    ``_EditArrayContraction`` editor is built over
    ``ArrayContraction(ArrayTensorProduct(*args), *contraction_indices)`` and
    rewritten until a full pass changes nothing:

    * a ``MatrixExpr`` argument whose two axes carry the same index, which no
      other argument uses, becomes its ``Trace``;
    * otherwise, indices shared with exactly one other argument are passed to
      ``_get_candidate_for_matmul_from_contraction``, which looks at the
      arguments after the current one. A found candidate is merged via
      ``_insert_candidate_into_editor`` (permutation tracking is updated
      first). A merge whose two remaining indices coincide becomes a
      ``Trace``.

    Returns the expression rebuilt by ``editor.to_array_contraction()``.
    """
    if len(contraction_indices) == 0:
        return _a2m_tensor_product(*args)

    ac = ArrayContraction(ArrayTensorProduct(*args), *contraction_indices)
    editor = _EditArrayContraction(ac)
    editor.track_permutation_start()

    # Restart the scan after every rewrite; stop once a pass changes nothing.
    while True:
        flag_stop: bool = True
        for i, arg_with_ind in enumerate(editor.args_with_ind):
            if not isinstance(arg_with_ind.element, MatrixExpr):
                continue

            first_index = arg_with_ind.indices[0]
            second_index = arg_with_ind.indices[1]

            # Number of arguments carrying each index (counted per argument,
            # judging from the trace check right below):
            first_frequency = editor.count_args_with_index(first_index)
            second_frequency = editor.count_args_with_index(second_index)

            # Both axes contracted together and nowhere else: a trace.
            if first_index is not None and first_frequency == 1 and first_index == second_index:
                flag_stop = False
                arg_with_ind.element = Trace(arg_with_ind.element)._normalize()
                arg_with_ind.indices = []
                break

            # Indices shared with exactly one other argument may form a
            # matrix product.
            # NOTE(review): a free axis (``None``) is also collected here when
            # exactly two arguments have one; presumably the candidate search
            # ignores ``None`` — confirm.
            scan_indices = []
            if first_frequency == 2:
                scan_indices.append(first_index)
            if second_frequency == 2:
                scan_indices.append(second_index)

            candidate, transpose, found_index = _get_candidate_for_matmul_from_contraction(scan_indices, editor.args_with_ind[i + 1:])
            if candidate is not None:
                flag_stop = False
                editor.track_permutation_merge(arg_with_ind, candidate)
                # The current matrix must be transposed when the match was on
                # its first axis:
                transpose1 = found_index == first_index
                new_arge, other_index = _insert_candidate_into_editor(editor, arg_with_ind, candidate, transpose1, transpose)
                if found_index == first_index:
                    new_arge.indices = [second_index, other_index]
                else:
                    new_arge.indices = [first_index, other_index]
                set_indices = set(new_arge.indices)
                if len(set_indices) == 1 and set_indices != {None}:
                    # This is a trace:
                    new_arge.element = Trace(new_arge.element)._normalize()
                    new_arge.indices = []
                editor.args_with_ind[i] = new_arge
                # TODO: is this break necessary?
                break

        if flag_stop:
            break

    editor.refresh_indices()
    return editor.to_array_contraction()
def _remove_diagonalized_identity_matrices(expr: ArrayDiagonal):
    """
    Absorb identity matrices diagonalized against a matrix with a size-1 dim.

    An ``Identity`` argument qualifies when it has one free axis (``None``)
    and one diagonal axis (a negative index). It is dropped when the other
    argument sharing that diagonal index meets three conditions: it is a
    ``MatrixExpr``, it has 1 in its shape, and it has a free axis. That other
    argument is then wrapped in ``DiagMatrix``.

    Returns ``(new_expr, removed)``. ``removed`` holds, for each absorbed
    identity, the absolute position of the first free axis of the argument it
    was merged into. It is mapped through
    ``ArrayDiagonal._push_indices_up(expr.diagonal_indices, ...)``.
    """
    assert isinstance(expr, ArrayDiagonal)
    editor = _EditArrayContraction(expr)
    # For every diagonal index (-1, -2, ...), the arguments carrying it. Kept
    # as lists in argument order (not sets, whose iteration order depends on
    # object hashes) so that the choice of ``other`` below is deterministic.
    mapping = {
        diag: [arg for arg in editor.args_with_ind if diag in arg.indices]
        for diag in range(-1, -1 - editor.number_of_diagonal_indices, -1)
    }
    removed = []
    for arg_with_ind in editor.args_with_ind:
        if not isinstance(arg_with_ind.element, Identity):
            continue
        indices = arg_with_ind.indices
        # Only identities with a free axis and a diagonal axis qualify:
        if None not in indices:
            continue
        if not any(ind is not None and ind < 0 for ind in indices):
            continue
        diag_ind = [ind for ind in indices if ind is not None][0]
        other = [arg for arg in mapping[diag_ind] if arg != arg_with_ind][0]
        if not isinstance(other.element, MatrixExpr):
            continue
        if 1 not in other.element.shape:
            continue
        if None not in other.indices:
            continue
        # Mark the identity for deletion and turn the partner into a diagonal
        # matrix instead:
        arg_with_ind.element = None
        none_index = other.indices.index(None)
        other.element = DiagMatrix(other.element)
        other_range = editor.get_absolute_range(other)
        removed.append(other_range[0] + none_index)
    editor.args_with_ind = [arg for arg in editor.args_with_ind if arg.element is not None]
    removed = ArrayDiagonal._push_indices_up(expr.diagonal_indices, removed, get_rank(expr.expr))
    return editor.to_array_contraction(), removed
def identify_hadamard_products(expr: Union[ArrayContraction, ArrayDiagonal]):
    """
    Replace groups of matrices sharing the same pair of indices by a Hadamard product.

    NOTE(review): a second ``identify_hadamard_products`` is defined later in
    this source. If both live in the same module, the later definition
    replaces this one at import time.

    For an ``ArrayDiagonal`` input, the diagonal axes are first written into
    the editor as negative indices (``-1 - i`` for the ``i``-th diagonal).
    Arguments without a free axis are then grouped by their index set. Each
    group that meets all of the following is replaced by ``hadamard_product``
    of its members, inserted after the first member:

    * it has exactly two indices;
    * it has at least two members;
    * it is not a fully contracted trace.

    Members whose indices are not in sorted order are transposed.

    For an ``ArrayContraction`` input, the rebuilt contraction is returned.
    For an ``ArrayDiagonal`` input, the remaining negative indices are turned
    back into diagonal index tuples and the result is
    ``PermuteDims(ArrayDiagonal(contraction, *diagonals), permutation)``.

    Raises ``NotImplementedError`` when ``expr`` is an ``ArrayDiagonal`` whose
    inner expression is neither an ``ArrayContraction`` nor an
    ``ArrayTensorProduct``.
    """
    # Flat axis position -> (argument position, axis within the argument).
    # NOTE(review): only used in the ArrayDiagonal branch below.
    mapping = _get_mapping_from_subranks(expr.subranks)

    editor: _EditArrayContraction
    if isinstance(expr, ArrayContraction):
        editor = _EditArrayContraction(expr)
    elif isinstance(expr, ArrayDiagonal):
        if isinstance(expr.expr, ArrayContraction):
            editor = _EditArrayContraction(expr.expr)
            diagonalized = ArrayContraction._push_indices_down(expr.expr.contraction_indices, expr.diagonal_indices)
        elif isinstance(expr.expr, ArrayTensorProduct):
            editor = _EditArrayContraction(None)
            editor.args_with_ind = [_ArgE(arg) for i, arg in enumerate(expr.expr.args)]
            diagonalized = expr.diagonal_indices
        else:
            raise NotImplementedError("not implemented")

        # Trick: add diagonalized indices as negative indices into the editor object:
        for i, e in enumerate(diagonalized):
            for j in e:
                arg_pos, rel_pos = mapping[j]
                editor.args_with_ind[arg_pos].indices[rel_pos] = -1 - i

    # Group arguments without free axes by their index set, and count how
    # often each index (including ``None``) occurs overall:
    map_contr_to_args: Dict[FrozenSet, List[_ArgE]] = defaultdict(list)
    map_ind_to_inds = defaultdict(int)
    for arg_with_ind in editor.args_with_ind:
        for ind in arg_with_ind.indices:
            map_ind_to_inds[ind] += 1
        if None in arg_with_ind.indices:
            continue
        map_contr_to_args[frozenset(arg_with_ind.indices)].append(arg_with_ind)

    k: FrozenSet[int]
    v: List[_ArgE]
    for k, v in map_contr_to_args.items():
        if len(k) != 2:
            # Hadamard product only defined for matrices:
            continue
        if len(v) == 1:
            # Hadamard product with a single argument makes no sense:
            continue
        # NOTE(review): this loop has no effect. Its ``continue`` only skips
        # to the next ``ind``, not to the next group.
        for ind in k:
            if map_ind_to_inds[ind] <= 2:
                # There is no other contraction, skip:
                continue

        # Check if expression is a trace:
        if all([map_ind_to_inds[j] == len(v) and j >= 0 for j in k]):
            # This is a trace
            continue

        # This is a Hadamard product:
        def check_transpose(x):
            # Axes are in "sorted" order; negative indices compare as -1 - i.
            x = [i if i >= 0 else -1 - i for i in x]
            return x == sorted(x)

        hp = hadamard_product(*[i.element if check_transpose(i.indices) else Transpose(i.element) for i in v])
        hp_indices = v[0].indices
        if not check_transpose(v[0].indices):
            hp_indices = list(reversed(hp_indices))
        editor.insert_after(v[0], _ArgE(hp, hp_indices))
        for i in v:
            editor.args_with_ind.remove(i)

    # Count the ranks of the arguments:
    counter = 0
    # Create a collector for the new diagonal indices:
    diag_indices = defaultdict(list)

    count_index_freq = Counter()
    for arg_with_ind in editor.args_with_ind:
        count_index_freq.update(Counter(arg_with_ind.indices))

    free_index_count = count_index_freq[None]

    # Construct the inverse permutation:
    inv_perm1 = []
    inv_perm2 = []
    # Keep track of which diagonal indices have already been processed:
    done = set([])

    # Counter for the diagonal indices:
    counter4 = 0

    for arg_with_ind in editor.args_with_ind:
        # If some diagonalization axes have been removed, they should be
        # permuted in order to keep the permutation.
        # Add permutation here
        counter2 = 0  # counter for the indices
        for i in arg_with_ind.indices:
            if i is None:
                inv_perm1.append(counter4)
                counter2 += 1
                counter4 += 1
                continue
            if i >= 0:
                continue
            # Reconstruct the diagonal indices:
            diag_indices[-1 - i].append(counter + counter2)
            if count_index_freq[i] == 1 and i not in done:
                inv_perm1.append(free_index_count - 1 - i)
                done.add(i)
            elif i not in done:
                inv_perm2.append(free_index_count - 1 - i)
                done.add(i)
            counter2 += 1
        # Remove negative indices to restore a proper editor object:
        arg_with_ind.indices = [i if i is not None and i >= 0 else None for i in arg_with_ind.indices]
        # After the reset above no index is negative, so this counts the
        # ``None`` entries (original free axes plus former diagonal axes):
        counter += len([i for i in arg_with_ind.indices if i is None or i < 0])

    inverse_permutation = inv_perm1 + inv_perm2
    permutation = _af_invert(inverse_permutation)

    if isinstance(expr, ArrayContraction):
        return editor.to_array_contraction()
    else:
        # Get the diagonal indices after the detection of HadamardProduct in the expression:
        diag_indices_filtered = [tuple(v) for v in diag_indices.values() if len(v) > 1]

        expr1 = editor.to_array_contraction()
        expr2 = ArrayDiagonal(expr1, *diag_indices_filtered)
        expr3 = PermuteDims(expr2, permutation)
        return expr3
def remove_identity_matrices(expr: ArrayContraction):
    """
    Remove identity matrices that only relabel an axis of another argument.

    A contraction index qualifies when:

    * it is shared by exactly one non-identity argument and one or more
      ``Identity`` arguments;
    * each of those identities has at least one free axis.

    The identities are then dropped and the index becomes a free axis of the
    non-identity argument. A final ``_permute_dims`` moves each freed axis to
    the position of the last free axis of the identities it replaced.

    Returns ``(new_expr, removed)``, where ``removed`` is the sorted list of
    free-axis positions of ``expr`` belonging to the dropped identities
    (excluding the one taken over by each freed axis).
    """
    editor = _EditArrayContraction(expr)
    removed: List[int] = []

    # Position of a freed axis -> identity axis position (``last_removed``)
    # it takes over; consumed when building the final permutation.
    permutation_map = {}

    # free_map[arg] is the number of free axes in all arguments before ``arg``:
    free_indices = list(accumulate([0] + [sum(i is None for i in arg.indices) for arg in editor.args_with_ind]))
    free_map = {k: v for k, v in zip(editor.args_with_ind, free_indices[:-1])}

    # (last_removed, ind) -> (non-identity argument, its indices before the update)
    update_pairs = {}

    for ind in range(editor.number_of_contraction_indices):
        args = editor.get_args_with_index(ind)
        identity_matrices = [i for i in args if isinstance(i.element, Identity)]
        number_identity_matrices = len(identity_matrices)
        # If the contraction involves a non-identity matrix and multiple identity matrices:
        if number_identity_matrices != len(args) - 1 or number_identity_matrices == 0:
            continue
        # Get the non-identity element:
        non_identity = [i for i in args if not isinstance(i.element, Identity)][0]
        # Check that all identity matrices have at least one free index
        # (otherwise they would be contractions to some other elements)
        if any(None not in i.indices for i in identity_matrices):
            continue
        # Mark the identity matrices for removal:
        for i in identity_matrices:
            i.element = None
            removed.extend(range(free_map[i], free_map[i] + len([j for j in i.indices if j is None])))
        last_removed = removed.pop(-1)
        # Store the argument itself. It is needed again below, and reading
        # the loop variable ``non_identity`` there would always yield the
        # argument of the *last* processed contraction.
        update_pairs[last_removed, ind] = (non_identity, non_identity.indices[:])
        # Remove the indices from the non-identity matrix, as the contraction
        # no longer exists:
        non_identity.indices = [None if i == ind else i for i in non_identity.indices]

    removed.sort()
    # Set for the repeated membership tests below (``removed`` stays a list):
    removed_set = set(removed)

    shifts = list(accumulate([1 if i in removed_set else 0 for i in range(get_rank(expr))]))
    for (last_removed, ind), (owner, owner_indices) in update_pairs.items():
        pos = [free_map[owner] + i for i, e in enumerate(owner_indices) if e == ind]
        assert len(pos) == 1
        for j in pos:
            permutation_map[j] = last_removed

    editor.args_with_ind = [i for i in editor.args_with_ind if i.element is not None]
    ret_expr = editor.to_array_contraction()
    permutation = []
    counter = 0
    counter2 = 0
    for j in range(get_rank(expr)):
        if j in removed_set:
            continue
        if counter2 in permutation_map:
            target = permutation_map[counter2]
            permutation.append(target - shifts[target])
            counter2 += 1
        else:
            # Skip positions already claimed by freed axes:
            while counter in permutation_map.values():
                counter += 1
            permutation.append(counter)
            counter += 1
            counter2 += 1

    ret_expr2 = _permute_dims(ret_expr, _af_invert(permutation))
    return ret_expr2, removed
def identify_removable_identity_matrices(expr):
    """
    Simplify identity and diagonal matrices inside an array contraction.

    The arguments of ``expr`` are scanned through an ``_EditArrayContraction``
    editor. The first applicable rewrite is applied and the scan restarts
    after every change:

    * an ``Identity`` with one free axis, whose other index gives
      ``editor.count_args_with_index(...) == 1``, is replaced by
      ``OneArray(k)`` (``k`` being the identity's dimension);
    * an ``Identity`` with both axes on the same index, where that index is
      carried by other arguments too, is dropped;
    * an argument for which ``ask(Q.diagonal(...))`` holds, with both axes on
      the same index that is carried by exactly two other arguments, is
      relinked between them (``A_ai B_bi D_ii -> A_ai D_ij B_bj``).

    Returns the expression rebuilt by ``editor.to_array_contraction()``.
    """
    editor = _EditArrayContraction(expr)

    flag: bool = True
    while flag:
        flag = False
        for arg_with_ind in editor.args_with_ind:
            if isinstance(arg_with_ind.element, Identity):
                k = arg_with_ind.element.shape[0]
                # Candidate for removal:
                if arg_with_ind.indices == [None, None]:
                    # Free identity matrix, will be cleared by _remove_trivial_dims:
                    continue
                elif None in arg_with_ind.indices:
                    ind = [j for j in arg_with_ind.indices if j is not None][0]
                    counted = editor.count_args_with_index(ind)
                    if counted == 1:
                        # Identity matrix contracted only on one index with itself,
                        # transform to a OneArray(k) element. The editor holds
                        # _ArgE wrappers, so wrap it; its single axis is free.
                        editor.insert_after(arg_with_ind, _ArgE(OneArray(k), [None]))
                        editor.args_with_ind.remove(arg_with_ind)
                        flag = True
                        break
                    elif counted > 2:
                        # Case counted = 2 is a matrix multiplication by identity matrix, skip it.
                        # Case counted > 2 is a multiple contraction,
                        # this is a case where the contraction becomes a diagonalization if the
                        # identity matrix is dropped.
                        continue
                elif arg_with_ind.indices[0] == arg_with_ind.indices[1]:
                    ind = arg_with_ind.indices[0]
                    counted = editor.count_args_with_index(ind)
                    if counted > 1:
                        editor.args_with_ind.remove(arg_with_ind)
                        flag = True
                        break
                    # Otherwise this is a trace; skip it, as it will be
                    # recognized somewhere else.
            elif ask(Q.diagonal(arg_with_ind.element)):
                if arg_with_ind.indices == [None, None]:
                    continue
                elif None in arg_with_ind.indices:
                    pass
                elif arg_with_ind.indices[0] == arg_with_ind.indices[1]:
                    ind = arg_with_ind.indices[0]
                    counted = editor.count_args_with_index(ind)
                    if counted == 3:
                        # A_ai B_bi D_ii ==> A_ai D_ij B_bj
                        ind_new = editor.get_new_contraction_index()
                        # Only the two arguments that actually carry ``ind``
                        # may be relinked; unrelated arguments must not be
                        # picked, or ``ind_new`` would end up on D alone.
                        other_args = [
                            j for j in editor.args_with_ind
                            if j != arg_with_ind and ind in j.indices
                        ]
                        other_args[1].indices = [ind_new if j == ind else j for j in other_args[1].indices]
                        arg_with_ind.indices = [ind, ind_new]
                        flag = True
                        break

    return editor.to_array_contraction()
def identify_hadamard_products(expr: tUnion[ArrayContraction, ArrayDiagonal]):
    """
    Rewrite groups of matrices sharing the same indices as Hadamard products.

    Arguments with no free axis (``None``) are grouped by their set of
    indices. Each group ``v`` with two or more members is rewritten as
    follows:

    * a two-index group becomes ``hadamard_product`` of its members. Members
      whose indices are not in sorted order are transposed; negative
      (diagonal) indices are compared as ``-1 - i``.
    * a two-index group whose indices are non-negative and each occur exactly
      ``len(v)`` times overall is fully contracted. It becomes
      ``Trace(first * hadamard_product(*rest).T)``.
    * a single-index group whose (non-negative) index occurs nowhere outside
      the group becomes the ``Trace`` of the Hadamard product of its members.

    The new argument is inserted after the group's first member and the
    members are removed. Returns ``editor.to_array_contraction()``.
    """

    def check_transpose(x):
        # True when the axes are in "sorted" order; negative (diagonal)
        # indices are compared through ``-1 - i``.
        x = [i if i >= 0 else -1 - i for i in x]
        return x == sorted(x)

    editor: _EditArrayContraction = _EditArrayContraction(expr)

    map_contr_to_args: tDict[FrozenSet, List[_ArgE]] = defaultdict(list)
    map_ind_to_inds: tDict[Optional[int], int] = defaultdict(int)
    for arg_with_ind in editor.args_with_ind:
        for ind in arg_with_ind.indices:
            map_ind_to_inds[ind] += 1
        if None in arg_with_ind.indices:
            continue
        map_contr_to_args[frozenset(arg_with_ind.indices)].append(arg_with_ind)

    k: FrozenSet[int]
    v: List[_ArgE]
    for k, v in map_contr_to_args.items():
        make_trace: bool = False
        if len(k) == 1:
            (only_index,) = k
            # This is a trace only if the arguments are fully contracted with
            # a single index that isn't used anywhere else. Count occurrences
            # over *all* arguments: ``map_contr_to_args`` omits arguments
            # with a free axis, which may still carry the index.
            occurrences_in_group = sum(len(arg.indices) for arg in v)
            if only_index < 0 or map_ind_to_inds[only_index] != occurrences_in_group:
                # Hadamard product only defined for matrices:
                continue
            make_trace = True
            first_element = S.One
        elif len(k) != 2:
            # Hadamard product only defined for matrices:
            continue

        if len(v) == 1:
            # Hadamard product with a single argument makes no sense:
            continue

        # No per-index "more than two occurrences" requirement is applied:
        # it would reject plain two-factor products such as A_ij B_ij with
        # both indices diagonal.

        # Check if expression is a trace (all indices non-negative and each
        # shared by exactly the members of this group):
        if all(map_ind_to_inds[j] == len(v) and j >= 0 for j in k):
            make_trace = True
            first_element = v[0].element
            if not check_transpose(v[0].indices):
                first_element = first_element.T
            hadamard_factors = v[1:]
        else:
            hadamard_factors = v

        # This is a Hadamard product:
        hp = hadamard_product(*[
            i.element if check_transpose(i.indices) else Transpose(i.element)
            for i in hadamard_factors
        ])
        hp_indices = v[0].indices
        if not check_transpose(hadamard_factors[0].indices):
            hp_indices = list(reversed(hp_indices))
        if make_trace:
            hp = Trace(first_element * hp.T)._normalize()
            hp_indices = []

        editor.insert_after(v[0], _ArgE(hp, hp_indices))
        for i in v:
            editor.args_with_ind.remove(i)

    return editor.to_array_contraction()