def test_abstract_backend_outer_product_not_implemented():
  """The base `AbstractBackend.outer_product` is expected to raise
  NotImplementedError."""
  base_backend = AbstractBackend()
  lhs = np.ones((2, 2))
  rhs = np.ones((2, 2))
  with pytest.raises(NotImplementedError):
    base_backend.outer_product(lhs, rhs)
def _jittable_ncon(tensors: List[Tensor], flat_labels: Tuple[int],
                   sizes: Tuple[int], con_order: Tuple[int],
                   out_order: Tuple[int],
                   backend_obj: AbstractBackend) -> Tensor:
  """
  Jittable Ncon function. Performs the contraction of `tensors`.

  The contraction proceeds in stages:
    1. partial traces within each tensor (via `_partial_trace`),
    2. summation over positive labels that appear only once in the network,
    3. pairwise contractions in the order given by `con_order`; a label that
       is currently a batch label (appearing more than twice, or a negative
       label appearing exactly twice) is postponed to the end of `con_order`,
       and pairs sharing batch labels are contracted via `_batch_cont`,
    4. outer products (or remaining batch contractions) of leftover tensors,
    5. a final transpose so the result's axes follow `out_order`.

  Args:
    tensors: List of tensors. Note: this list is modified in place.
    flat_labels: The labels of all tensors, concatenated in tensor order.
    sizes: Tuple of int used to reconstruct `network_structure` from
      `flat_labels` (the number of labels of each tensor).
    con_order: Order of the contraction.
    out_order: Order of the final axis order.
    backend_obj: A backend object.

  Returns:
    The final tensor after contraction.

  Raises:
    ValueError: If every label remaining in `con_order` has been postponed
      in a row without any contraction taking place (no progress possible).
  """
  # some jax-juggling to avoid retracing ...
  slices = np.append(0, np.cumsum(sizes))
  network_structure = [
      np.array(flat_labels)[slices[n]:slices[n + 1]]
      for n in range(len(slices) - 1)
  ]
  con_order = np.array(con_order)
  out_order = np.array(out_order)
  # copies kept only for the error message below
  # pylint: disable=unnecessary-comprehension
  init_con_order = [c for c in con_order]
  init_network_structure = [list(c) for c in network_structure]

  # partial trace
  for n, tensor in enumerate(tensors):
    tensors[n], network_structure[n], contracted_labels = _partial_trace(
        tensor, network_structure[n], backend_obj)
    if len(contracted_labels) > 0:
      # labels traced out within a single tensor are no longer contracted
      con_order = np.delete(
          con_order,
          np.intersect1d(
              con_order,
              contracted_labels,
              return_indices=True,
              assume_unique=True)[1])

  # contracted all positive labels appearing only once in `network_structure`
  unique_labels, label_cnts = np.unique(
      np.concatenate(network_structure), return_counts=True)
  contractable_labels = unique_labels[np.logical_and(label_cnts == 1,
                                                     unique_labels > 0)]
  # update con_order
  if len(contractable_labels) > 0:
    con_order = np.delete(
        con_order, np.nonzero(np.isin(con_order, contractable_labels))[0])
  # collapse axes of single-labelled tensors
  locs = [
      n for n, labels in enumerate(network_structure)
      if np.any(np.isin(labels, contractable_labels))
  ]
  for loc in locs:
    labels = network_structure[loc]
    contractable_inds = np.nonzero(np.isin(labels, contractable_labels))[0]
    network_structure[loc] = np.delete(labels, contractable_inds)
    tensors[loc] = backend_obj.sum(tensors[loc], tuple(contractable_inds))

  # perform binary and batch contractions
  skip_counter = 0
  unique_labels, label_cnts = np.unique(
      np.concatenate(network_structure), return_counts=True)
  mask = np.logical_or(label_cnts > 2,
                       np.logical_and(label_cnts == 2, unique_labels < 0))
  batch_labels = unique_labels[mask]
  batch_cnts = label_cnts[mask]

  while len(con_order) > 0:
    # the next index to be contracted
    cont_ind = con_order[0]
    if cont_ind in batch_labels:
      # if its still a batch index then do it later
      con_order = np.append(np.delete(con_order, 0), cont_ind)
      skip_counter += 1
      # avoid being stuck in an infinite loop
      if skip_counter > len(con_order):
        raise ValueError(
            f"ncon seems stuck in an infinite loop. \n"
            f"Please check if `con_order` = {init_con_order} is "
            f"a valid contraction order for \n"
            f"`network_structure` = {init_network_structure}")
      continue

    # A contraction is about to happen, i.e. progress is made. Only
    # consecutive postponements signal a stuck contraction, so reset the
    # counter here (otherwise valid orders that postpone a batch label more
    # than once could be rejected).
    skip_counter = 0

    # find locations of `cont_ind` in `network_structure`
    locs = [
        n for n, labels in enumerate(network_structure)
        if sum(labels == cont_ind) > 0
    ]

    # pop the later position first so the earlier index stays valid
    t2 = tensors.pop(locs[1])
    t1 = tensors.pop(locs[0])
    labels_t2 = network_structure.pop(locs[1])
    labels_t1 = network_structure.pop(locs[0])
    common_labels, t1_cont, t2_cont = np.intersect1d(
        labels_t1, labels_t2, assume_unique=True, return_indices=True)
    # check if there are batch labels (i.e. labels appearing more than twice
    # in `network_structure`).
    common_batch_labels = np.intersect1d(
        batch_labels, common_labels, assume_unique=True)
    if common_batch_labels.shape[0] > 0:
      # case1: both tensors have one or more common batch indices -> use matmul
      ix, _ = np.nonzero(
          batch_labels[:, None] == common_batch_labels[None, :])
      # reduce the counts of these labels in `batch_cnts` by 1
      batch_cnts[ix] -= 1
      # if the count of a positive label falls below 3
      # remove it from `batch_labels`
      mask = np.logical_or(
          batch_cnts > 2, np.logical_and(batch_cnts == 2, batch_labels < 0))
      batch_labels = batch_labels[mask]
      batch_cnts = batch_cnts[mask]

      tensors, network_structure, con_order = _batch_cont(
          t1, t2, tensors, network_structure, con_order, common_batch_labels,
          labels_t1, labels_t2, backend_obj)
    # in all other cases do a regular tensordot
    else:
      ind_sort = np.argsort(t1_cont)
      tensors.append(
          backend_obj.tensordot(
              t1,
              t2,
              axes=(tuple(t1_cont[ind_sort]), tuple(t2_cont[ind_sort]))))
      network_structure.append(
          np.append(
              np.delete(labels_t1, t1_cont), np.delete(labels_t2, t2_cont)))

      # remove contracted labels from con_order
      con_order = np.delete(
          con_order,
          np.intersect1d(
              con_order,
              common_labels,
              assume_unique=True,
              return_indices=True)[1])

  # perform outer products and remaining batch contractions
  while len(tensors) > 1:
    unique_labels, label_cnts = np.unique(
        np.concatenate(network_structure), return_counts=True)
    batch_labels = unique_labels[np.logical_or(
        label_cnts > 2, np.logical_and(label_cnts == 2, unique_labels < 0))]

    t2 = tensors.pop()
    t1 = tensors.pop()
    labels_t2 = network_structure.pop()
    labels_t1 = network_structure.pop()
    # check if there are negative batch indices left
    # (have to be collapsed to a single one)
    common_labels, _, _ = np.intersect1d(
        labels_t1, labels_t2, assume_unique=True, return_indices=True)
    common_batch_labels = np.intersect1d(
        batch_labels, common_labels, assume_unique=True)
    if common_batch_labels.shape[0] > 0:
      # collapse all negative batch indices
      tensors, network_structure, con_order = _batch_cont(
          t1, t2, tensors, network_structure, con_order, common_batch_labels,
          labels_t1, labels_t2, backend_obj)
    else:
      tensors.append(backend_obj.outer_product(t1, t2))
      network_structure.append(np.append(labels_t1, labels_t2))

  # if necessary do a final permutation
  if len(network_structure[0]) > 0:
    # i2[k] is the position in the current labels of out_order[k]
    i1, i2 = np.nonzero(out_order[:, None] == network_structure[0][None, :])
    return backend_obj.transpose(tensors[0], tuple(i1[i2]))
  return tensors[0]
def _jittable_ncon(tensors: List[Tensor], flat_labels: Tuple[int],
                   sizes: Tuple[int], con_order: Tuple[int],
                   out_order: Tuple[int],
                   backend_obj: AbstractBackend) -> Tensor:
  """
  Jittable Ncon function. Performs the contraction of `tensors`.

  The contraction proceeds in stages:
    1. partial traces within each tensor (via `_partial_trace`),
    2. summation over positive labels that appear only once in the network,
    3. pairwise contractions in the order given by `con_order`; a label that
       is currently a batch label (appearing more than twice, or a negative
       label appearing exactly twice) is postponed to the end of `con_order`,
       and pairs sharing batch labels are contracted via `_batch_cont`,
    4. outer products (or remaining batch contractions) of leftover tensors,
    5. a final transpose so the result's axes follow `out_order`.

  Args:
    tensors: List of tensors. Note: this list is modified in place.
    flat_labels: The labels of all tensors, concatenated in tensor order.
    sizes: Tuple of int used to reconstruct `network_structure` from
      `flat_labels` (the number of labels of each tensor).
    con_order: Order of the contraction.
    out_order: Order of the final axis order.
    backend_obj: A backend object.

  Returns:
    The final tensor after contraction.

  Raises:
    ValueError: If every label remaining in `con_order` has been postponed
      in a row without any contraction taking place (no progress possible).
  """
  # some jax-juggling to avoid retracing ...
  flat_labels = list(flat_labels)
  slices = np.append(0, np.cumsum(sizes))
  network_structure = [
      flat_labels[slices[n]:slices[n + 1]] for n in range(len(slices) - 1)
  ]
  out_order = list(out_order)
  con_order = list(con_order)
  # copies kept only for the error message below
  # pylint: disable=unnecessary-comprehension
  init_con_order = [c for c in con_order]
  init_network_structure = [c for c in network_structure]

  # partial trace
  for n, tensor in enumerate(tensors):
    tensors[n], network_structure[n], contracted_labels = _partial_trace(
        tensor, network_structure[n], backend_obj)
    if len(contracted_labels) > 0:
      # labels traced out within a single tensor are no longer contracted
      con_order = [c for c in con_order if c not in contracted_labels]

  flat_labels = [l for sublist in network_structure for l in sublist]
  # contracted all positive labels appearing only once in `network_structure`
  contractable_labels = [
      l for l in flat_labels if (flat_labels.count(l) == 1) and (l > 0)
  ]
  # update con_order
  if len(contractable_labels) > 0:
    con_order = [o for o in con_order if o not in contractable_labels]
  # collapse axes of single-labelled tensors
  locs = []
  for n, labels in enumerate(network_structure):
    if len(set(labels).intersection(contractable_labels)) > 0:
      locs.append(n)

  for loc in locs:
    labels = network_structure[loc]
    # Only look up the single-occurrence labels that live on THIS tensor:
    # `contractable_labels` spans the whole network, and `labels.index`
    # raises ValueError for a label carried by a different tensor.
    contractable_inds = [
        labels.index(l) for l in contractable_labels if l in labels
    ]
    network_structure[loc] = [
        l for l in labels if l not in contractable_labels
    ]
    tensors[loc] = backend_obj.sum(tensors[loc], tuple(contractable_inds))

  # perform binary and batch contractions
  skip_counter = 0
  batch_labels = []
  batch_cnts = []
  for l in set(flat_labels):
    cnt = flat_labels.count(l)
    if (cnt > 2) or (cnt == 2 and l < 0):
      batch_labels.append(l)
      batch_cnts.append(cnt)

  while len(con_order) > 0:
    # the next index to be contracted
    cont_ind = con_order[0]
    if cont_ind in batch_labels:
      # if its still a batch index then do it later
      con_order.append(con_order.pop(0))
      skip_counter += 1
      # avoid being stuck in an infinite loop
      if skip_counter > len(con_order):
        raise ValueError(
            f"ncon seems stuck in an infinite loop. \n"
            f"Please check if `con_order` = {init_con_order} is "
            f"a valid contraction order for \n"
            f"`network_structure` = {init_network_structure}")
      continue

    # A contraction is about to happen, i.e. progress is made. Only
    # consecutive postponements signal a stuck contraction, so reset the
    # counter here (otherwise valid orders that postpone a batch label more
    # than once could be rejected).
    skip_counter = 0

    # find locations of `cont_ind` in `network_structure`
    locs = [
        n for n, labels in enumerate(network_structure) if cont_ind in labels
    ]

    # pop the later position first so the earlier index stays valid
    t2 = tensors.pop(locs[1])
    t1 = tensors.pop(locs[0])
    labels_t2 = network_structure.pop(locs[1])
    labels_t1 = network_structure.pop(locs[0])
    common_labels, t1_cont, t2_cont = label_intersection(labels_t1, labels_t2)
    # check if there are batch labels (i.e. labels appearing more than twice
    # in `network_structure`).
    common_batch_labels = set(batch_labels).intersection(common_labels)
    if len(common_batch_labels) > 0:
      # case1: both tensors have one or more common batch indices -> use matmul
      ix = np.nonzero(
          np.array(batch_labels)[:, None] == np.array(
              list(common_batch_labels))[None, :])[0]
      # reduce the counts of these labels in `batch_cnts` by 1
      delete = []
      for i in ix:
        batch_cnts[i] -= 1
        if (batch_labels[i] > 0) and (batch_cnts[i] <= 2):
          delete.append(i)
        elif (batch_labels[i] < 0) and (batch_cnts[i] < 2):
          delete.append(i)

      # delete from the back so earlier indices stay valid
      for i in sorted(delete, reverse=True):
        del batch_cnts[i]
        del batch_labels[i]

      tensors, network_structure, con_order = _batch_cont(
          t1, t2, tensors, network_structure, con_order, common_batch_labels,
          labels_t1, labels_t2, backend_obj)
    # in all other cases do a regular tensordot
    else:
      # for len(t1_cont)~<20 this is faster than np.argsort
      ind_sort = [t1_cont.index(l) for l in sorted(t1_cont)]
      tensors.append(
          backend_obj.tensordot(
              t1,
              t2,
              axes=(tuple([t1_cont[i] for i in ind_sort]),
                    tuple([t2_cont[i] for i in ind_sort]))))
      new_labels = [l for l in labels_t1 if l not in common_labels
                   ] + [l for l in labels_t2 if l not in common_labels]
      network_structure.append(new_labels)
      # remove contracted labels from con_order
      con_order = [c for c in con_order if c not in common_labels]

  # perform outer products and remaining batch contractions
  while len(tensors) > 1:
    t2 = tensors.pop()
    t1 = tensors.pop()
    labels_t2 = network_structure.pop()
    labels_t1 = network_structure.pop()
    # check if there are negative batch indices left
    # (have to be collapsed to a single one)
    common_labels, _, _ = label_intersection(labels_t1, labels_t2)
    common_batch_labels = set(batch_labels).intersection(common_labels)
    if len(common_batch_labels) > 0:
      # collapse all negative batch indices
      tensors, network_structure, con_order = _batch_cont(
          t1, t2, tensors, network_structure, con_order, common_batch_labels,
          labels_t1, labels_t2, backend_obj)
    else:
      tensors.append(backend_obj.outer_product(t1, t2))
      network_structure.append(labels_t1 + labels_t2)

  # if necessary do a final permutation
  if len(network_structure[0]) > 1:
    labels = network_structure[0]
    final_order = tuple([labels.index(l) for l in out_order])
    return backend_obj.transpose(tensors[0], final_order)
  return tensors[0]
def _jittable_ncon(tensors: Sequence[Tensor], flat_labels: Tuple[int],
                   sizes: Tuple[int], con_order: Tuple[int],
                   out_order: Tuple[int], backend_obj: AbstractBackend) -> Any:
  """Jittable Ncon function.

  Performs partial traces within each tensor, then the binary contractions
  listed in `con_order`, then outer products of any remaining tensors, and
  finally transposes the result so its axes follow `out_order`.

  Args:
    tensors: List of tensors. Must be a `list`; it is modified in place.
    flat_labels: The labels of all tensors, concatenated in tensor order.
    sizes: The number of labels of each tensor; used together with
      `flat_labels` to reconstruct the per-tensor `network_structure`.
    con_order: Order of the contraction.
    out_order: Order of the final axis order.
    backend_obj: A backend object.

  Returns:
    The final tensor after contraction.

  Raises:
    ValueError: If `tensors` is not a list.
  """
  # `tensors` is mutated in place (item assignment, pop, append) below, so
  # reject other sequence types before doing any work.
  if not isinstance(tensors, list):
    raise ValueError("`tensors` is not a list")

  # some jax-juggling to avoid retracing ...
  slices = np.append(0, np.cumsum(sizes))
  network_structure = [
      np.array(flat_labels)[slices[n]:slices[n + 1]]
      for n in range(len(slices) - 1)
  ]
  con_order = np.array(con_order)
  out_order = np.array(out_order)

  # partial trace
  for n, tensor in enumerate(tensors):
    tensors[n], network_structure[n], contracted_labels = _partial_trace(
        tensor, network_structure[n], backend_obj)
    # labels traced out within a single tensor are no longer contracted
    con_order = np.delete(
        con_order,
        np.intersect1d(con_order, contracted_labels, return_indices=True)[1])

  # binary contractions
  while len(con_order) > 0:
    cont_ind = con_order[0]  # the next index to be contracted
    locs = np.sort(
        np.nonzero([
            np.isin(cont_ind, labels) for labels in network_structure
        ])[0])
    # pop the later position first so the earlier index stays valid
    t2 = tensors.pop(locs[1])
    t1 = tensors.pop(locs[0])
    labels_t2 = network_structure.pop(locs[1])
    labels_t1 = network_structure.pop(locs[0])

    common_labels, t1_cont, t2_cont = np.intersect1d(
        labels_t1, labels_t2, assume_unique=True, return_indices=True)

    ind_sort = np.argsort(t1_cont)
    tensors.append(
        backend_obj.tensordot(
            t1,
            t2,
            axes=(tuple(t1_cont[ind_sort]), tuple(t2_cont[ind_sort]))))
    network_structure.append(
        np.append(
            np.delete(labels_t1, t1_cont), np.delete(labels_t2, t2_cont)))

    # remove contracted labels from con_order
    con_order = np.delete(
        con_order,
        np.intersect1d(
            con_order, common_labels, assume_unique=True,
            return_indices=True)[1])

  # outer products
  while len(tensors) > 1:
    t2 = tensors.pop()
    t1 = tensors.pop()
    labels_t2 = network_structure.pop()
    labels_t1 = network_structure.pop()

    tensors.append(backend_obj.outer_product(t1, t2))
    network_structure.append(np.append(labels_t1, labels_t2))

  # final permutation
  if len(network_structure[0]) > 0:
    # i2[k] is the position in the current labels of out_order[k]
    i1, i2 = np.nonzero(out_order[:, None] == network_structure[0][None, :])
    return backend_obj.transpose(tensors[0], tuple(i1[i2]))
  return tensors[0]