def _batch_cont(
    t1: Tensor,
    t2: Tensor,
    tensors: List[Tensor],
    network_structure: List[List],
    con_order: List,
    common_batch_labels: Set,
    labels_t1: List,
    labels_t2: List,
    backend_obj: AbstractBackend,
) -> Tuple[List[Tensor], List[List], List]:
  """
  Subroutine for performing a batched contraction of tensors `t1` and `t2`.

  Labels shared by both tensors are split into two groups:
    * batch labels (`common_batch_labels`), which are kept in the result and
      act as the batch dimensions of a batched matrix product;
    * all other shared labels, which are summed over.
  Each tensor is transposed and reshaped into a rank-3 array
  (`t1` -> (batch, free_t1, contracted), `t2` -> (batch, contracted, free_t2)),
  the two are multiplied with `backend_obj.matmul`, and the product is
  reshaped to (batch axes..., free axes of `t1`..., free axes of `t2`...).

  The result is appended to `tensors` and its labels to `network_structure`
  (both lists are modified in place); `t1` and `t2` are not removed here.

  NOTE(review): labels are located with `list.index`, so each label is
  assumed to occur at most once per tensor.
  NOTE(review): this file defines `_batch_cont` a second time further down;
  that later definition replaces this one at import time.

  Args:
    t1: A Tensor.
    t2: A Tensor.
    tensors: List of Tensor objects.
    network_structure: The canonical labels of the networks.
    con_order: Array of contracted labels.
    common_batch_labels: The common batch labels of `t1` and `t2`. The batch
      axes of the result follow the iteration order of this collection.
    labels_t1: The labels of `t1`.
    labels_t2: The labels of `t2`.
    backend_obj: A backend object.

  Returns:
    List[Tensor]: Updated list of tensors.
    List[List]: Updated `network_structure`.
    List: Updated `con_order` (contraction order), with the labels summed
      over by this call removed.
  """
  batch_labels = list(common_batch_labels)
  batch_set = set(batch_labels)

  # Positions of the batch labels. Both position lists follow the same label
  # order, so the batch axes of the two tensors line up after transposition.
  t1_batch_pos = [labels_t1.index(l) for l in batch_labels]
  t2_batch_pos = [labels_t2.index(l) for l in batch_labels]

  # Non-batch labels present on both tensors are summed over. They are taken
  # in order of appearance in `labels_t1` (not set-iteration order) so the
  # intermediate layout is deterministic; the order does not affect the
  # result because these axes are contracted away.
  labels_t2_set = set(labels_t2)
  contracted = [
      l for l in labels_t1 if l not in batch_set and l in labels_t2_set
  ]
  contracted_set = set(contracted)
  t1_cont = [labels_t1.index(l) for l in contracted]
  t2_cont = [labels_t2.index(l) for l in contracted]

  # Every remaining axis is free and keeps its original relative order.
  free_pos_t1 = [
      n for n, l in enumerate(labels_t1)
      if l not in batch_set and l not in contracted_set
  ]
  free_pos_t2 = [
      n for n, l in enumerate(labels_t2)
      if l not in batch_set and l not in contracted_set
  ]

  t1_shape = np.array(backend_obj.shape_tuple(t1))
  t2_shape = np.array(backend_obj.shape_tuple(t2))
  # Rank-3 shapes for a batched matmul:
  # (batch, free_t1, contracted) @ (batch, contracted, free_t2).
  # An empty group collapses to a dimension of size 1 (np.prod of nothing).
  newshape_t1 = (np.prod(t1_shape[t1_batch_pos]),
                 np.prod(t1_shape[free_pos_t1]),
                 np.prod(t1_shape[t1_cont]))
  newshape_t2 = (np.prod(t2_shape[t2_batch_pos]),
                 np.prod(t2_shape[t2_cont]),
                 np.prod(t2_shape[free_pos_t2]))

  # Bring the batch axes to the front and group free/contracted axes.
  order_t1 = tuple(t1_batch_pos + free_pos_t1 + t1_cont)
  order_t2 = tuple(t2_batch_pos + t2_cont + free_pos_t2)
  mat1 = backend_obj.reshape(backend_obj.transpose(t1, order_t1), newshape_t1)
  mat2 = backend_obj.reshape(backend_obj.transpose(t2, order_t2), newshape_t2)
  result = backend_obj.matmul(mat1, mat2)

  # Un-flatten the grouped axes back into the individual labelled axes.
  final_shape = tuple(
      np.concatenate([
          t1_shape[t1_batch_pos], t1_shape[free_pos_t1],
          t2_shape[free_pos_t2]
      ]))
  result = backend_obj.reshape(result, final_shape)

  # Record the new tensor and its labels; drop the summed labels from the
  # remaining contraction order.
  new_labels = ([labels_t1[i] for i in t1_batch_pos] +
                [labels_t1[i] for i in free_pos_t1] +
                [labels_t2[i] for i in free_pos_t2])
  network_structure.append(new_labels)
  tensors.append(result)
  con_order = [c for c in con_order if c not in contracted_set]
  return tensors, network_structure, con_order
def _batch_cont(
    t1: Tensor,
    t2: Tensor,
    tensors: List[Tensor],
    network_structure: List[np.ndarray],
    con_order: np.ndarray,
    common_batch_labels: np.ndarray,
    labels_t1: np.ndarray,
    labels_t2: np.ndarray,
    backend_obj: AbstractBackend,
) -> Tuple[List[Tensor], List[np.ndarray], np.ndarray]:
  """
  Subroutine for performing a batched contraction of tensors `t1` and `t2`.

  Labels shared by both tensors are split into two groups:
    * batch labels (`common_batch_labels`), which are kept in the result and
      act as the batch dimensions of a batched matrix product;
    * all other shared labels, which are summed over.
  Each tensor is transposed and reshaped into a rank-3 array
  (`t1` -> (batch, free_t1, contracted), `t2` -> (batch, contracted, free_t2)),
  the two are multiplied with `backend_obj.matmul`, and the product is
  reshaped to (batch axes..., free axes of `t1`..., free axes of `t2`...).
  Batch axes of the result appear in ascending label order (as returned by
  `np.intersect1d`); free axes keep their original relative order.

  The result is appended to `tensors` and its labels to `network_structure`
  (both lists are modified in place); `t1` and `t2` are not removed here.

  NOTE(review): `assume_unique=True` is passed to `np.intersect1d`, so each
  label is assumed to occur at most once per tensor.

  Args:
    t1: A Tensor.
    t2: A Tensor.
    tensors: List of Tensor objects.
    network_structure: The canonical labels of the networks.
    con_order: Array of contracted labels.
    common_batch_labels: The common batch labels of `t1` and `t2`.
    labels_t1: The labels of `t1`.
    labels_t2: The labels of `t2`.
    backend_obj: A backend object.

  Returns:
    List[Tensor]: Updated list of tensors.
    List[np.ndarray]: Updated `network_structure`.
    np.ndarray: Updated `con_order` (contraction order), with the labels
      summed over by this call removed. An empty `con_order` is returned
      unchanged.
  """
  # Positions of the batch labels. `np.intersect1d` yields the common labels
  # sorted, so `t1_batch_pos[i]` and `t2_batch_pos[i]` refer to the same
  # label and the batch axes line up after transposition.
  _, _, t1_batch_pos = np.intersect1d(
      common_batch_labels, labels_t1, assume_unique=True, return_indices=True)
  _, _, t2_batch_pos = np.intersect1d(
      common_batch_labels, labels_t2, assume_unique=True, return_indices=True)

  # Non-batch labels present on both tensors are summed over.
  non_batch_labels_t1 = labels_t1[np.isin(
      labels_t1, common_batch_labels, invert=True)]
  non_batch_labels_t2 = labels_t2[np.isin(
      labels_t2, common_batch_labels, invert=True)]
  common_contracted_labels = np.intersect1d(
      non_batch_labels_t1, non_batch_labels_t2, assume_unique=True)
  _, _, t1_cont = np.intersect1d(
      common_contracted_labels, labels_t1, assume_unique=True,
      return_indices=True)
  _, _, t2_cont = np.intersect1d(
      common_contracted_labels, labels_t2, assume_unique=True,
      return_indices=True)

  # Every remaining axis is free; `np.setdiff1d` returns them in ascending
  # (i.e. original) order.
  free_pos_t1 = np.setdiff1d(
      np.arange(len(labels_t1)), np.concatenate([t1_cont, t1_batch_pos]))
  free_pos_t2 = np.setdiff1d(
      np.arange(len(labels_t2)), np.concatenate([t2_cont, t2_batch_pos]))

  t1_shape = np.array(backend_obj.shape_tuple(t1))
  t2_shape = np.array(backend_obj.shape_tuple(t2))
  # Rank-3 shapes for a batched matmul:
  # (batch, free_t1, contracted) @ (batch, contracted, free_t2).
  # An empty group collapses to a dimension of size 1 (np.prod of nothing).
  newshape_t1 = (np.prod(t1_shape[t1_batch_pos]),
                 np.prod(t1_shape[free_pos_t1]),
                 np.prod(t1_shape[t1_cont]))
  newshape_t2 = (np.prod(t2_shape[t2_batch_pos]),
                 np.prod(t2_shape[t2_cont]),
                 np.prod(t2_shape[free_pos_t2]))

  # Bring the batch axes to the front and group free/contracted axes.
  order_t1 = tuple(np.concatenate([t1_batch_pos, free_pos_t1, t1_cont]))
  order_t2 = tuple(np.concatenate([t2_batch_pos, t2_cont, free_pos_t2]))
  mat1 = backend_obj.reshape(backend_obj.transpose(t1, order_t1), newshape_t1)
  mat2 = backend_obj.reshape(backend_obj.transpose(t2, order_t2), newshape_t2)
  result = backend_obj.matmul(mat1, mat2)

  # Un-flatten the grouped axes back into the individual labelled axes.
  final_shape = tuple(
      np.concatenate([
          t1_shape[t1_batch_pos], t1_shape[free_pos_t1],
          t2_shape[free_pos_t2]
      ]))
  result = backend_obj.reshape(result, final_shape)

  # Record the new tensor and its labels.
  new_labels = np.concatenate([
      labels_t1[t1_batch_pos], labels_t1[free_pos_t1], labels_t2[free_pos_t2]
  ])
  network_structure.append(new_labels)
  tensors.append(result)

  # Drop the labels summed over here from the remaining contraction order.
  # A boolean mask preserves order and removes every occurrence.
  if len(con_order) > 0:
    con_order = np.asarray(con_order)
    con_order = con_order[np.isin(
        con_order, common_contracted_labels, invert=True)]
  return tensors, network_structure, con_order