def contractDiag(node: tn.Node, diag: np.ndarray, edgeNum: int):
    """Contract a diagonal matrix into ``node`` along axis ``edgeNum``, in place.

    The diagonal matrix is given as the 1D vector ``diag`` of its diagonal
    entries, so the operation multiplies slice ``i`` of axis ``edgeNum`` of
    the node's tensor by ``diag[i]``.

    Args:
        node: Node whose tensor is scaled in place.
        diag: 1D array of diagonal entries; its length must equal the
            dimension of axis ``edgeNum``.
        edgeNum: Index of the tensor axis to contract the diagonal into.

    Returns:
        The same node, with its tensor updated.
    """
    # Broadcast diag along the target axis in one vectorized multiply.
    # The previous version transposed the target axis to the front, scaled
    # each slice in a Python-level loop, and transposed back — O(dim)
    # Python iterations and two permutations for what broadcasting does in
    # a single C-level pass.  ``*=`` keeps the original in-place semantics.
    bcast_shape = [1] * node.tensor.ndim
    bcast_shape[edgeNum] = -1
    node.tensor *= np.asarray(diag).reshape(bcast_shape)
    return node
def pairwise_reduction(net: BatchTensorNetwork, node: tensornetwork.Node,
                       edge: tensornetwork.Edge) -> tensornetwork.Node:
  """Parallel contraction of matrix chains.

  The operation performed by this function is described in Fig. 4 of the
  paper `TensorNetwork for Machine Learning`. It leads to a more efficient
  implementation of the MPS classifier both in terms of predictions and
  automatic gradient calculation.

  The idea is that the whole MPS side is saved in memory as one node that
  carries an artificial "space" edge. This function removes this additional
  index by performing the pairwise contractions as shown in the Figure.

  Args:
    net: TensorNetwork that contains the node we want to reduce.
    node: Node to reduce pairwise. The corresponding tensor should have
      the form (..., space edge, ..., a, b) and matrix multiplications will
      be performed over the last two indices using matmul.
    edge: Space edge of the node.

  Returns:
    node: Node after the reduction. Has the shape of given node with the
      `edge` removed.
  """
  # NOTE: This method could be included in the BatchedTensorNetwork class
  # however it seems better to be separated because (at least with the current
  # implementation) it performs a very specialized/non-general operation.
  # It also uses tf.matmul which restricts the backend, however this can be
  # easily generalized since all the backends support batched matmul.
  if not edge.is_dangling():
    raise ValueError("Cannot reduce non-dangling edge '{}'".format(edge))
  if edge.node1 is not node:
    raise ValueError("Edge '{}' does not belong to node '{}'".format(
        edge, node))

  tensor = node.tensor
  # Number of matrices in the chain = size of the "space" axis.
  size = int(tensor.shape[edge.axis1])

  # Bring reduction edge in first position by swapping axis 0 with the
  # space axis (a two-element swap permutation, not a full rotation).
  edge_order = list(range(len(list(tensor.shape))))
  edge_order[0] = edge.axis1
  edge_order[edge.axis1] = 0
  tensor = net.backend.transpose(tensor, edge_order)

  # Remove edge to be reduced from node and shift the axis bookkeeping of
  # every edge that sat after it down by one, so the remaining edges stay
  # consistent with the tensor that will be assigned at the end.
  node.edges.pop(edge.axis1)
  for e in node.edges[edge.axis1:]:
    if e.node1 is e.node2:
      # A trace edge references the node twice; decrementing only one of
      # its two axes here would corrupt it, so bail out explicitly.
      raise NotImplementedError("Cannot binary reduce node "
                                "'{}' with trace edge '{}'".format(node, e))
    if e.node1 is node:
      e.axis1 -= 1
    else:
      e.axis2 -= 1

  # The idea for this implementation is from jemisjoky/TorchMPS:
  # contract the chain in O(log(size)) rounds, each round multiplying
  # adjacent pairs with one batched matmul. An odd trailing element
  # ("leftover") is carried to the next round unchanged.
  while size > 1:
    half_size = size // 2
    nice_size = 2 * half_size
    leftover = tensor[nice_size:]
    tensor = tf.matmul(tensor[0:nice_size:2], tensor[1:nice_size:2])
    tensor = net.backend.concat([tensor, leftover], axis=0)
    size = half_size + int(size % 2 == 1)

  # Space axis is now of length 1; drop it.
  node.tensor = tensor[0]
  return node
def svdTruncation(node: tn.Node,
                  leftEdges: List[int], rightEdges: List[int],
                  dir: str, maxBondDim=128, leftName='U', rightName='V',
                  edgeName='default', normalize=False, maxTrunc=0):
    """Split ``node`` with a truncated SVD across the given edge partition.

    Args:
        node: Node to decompose.
        leftEdges: Indices (into ``node.edges``) grouped on the U side.
        rightEdges: Indices grouped on the V side.
        dir: Where the singular values are absorbed:
            '>>'  -> returns [U, S@V, truncErr] (S absorbed right),
            '<<'  -> returns [U@S, V, truncErr] (S absorbed left),
            '>*<' -> returns [U, S, V, truncErr] with S kept separate
                     (S is a vector node holding the singular values).
        maxBondDim: Upper bound on kept singular values (further capped by
            ``getAppropriateMaxBondDim`` from the edge dimensions).
        leftName, rightName: Names for the U / V nodes.
        edgeName: Name given to the new bond edge(s).
        normalize: If True, divide S by the norm of the singular values.
        maxTrunc: If > 0, additionally drop singular values that round to
            zero at ``maxTrunc`` decimal digits of ``s / norm``.

    Returns:
        [l, r, truncErr] for '>>' / '<<', or [U, S, V, truncErr] for '>*<'.

    Raises:
        ValueError: If ``dir`` is not one of '>>', '<<', '>*<'.
    """
    maxBondDim = getAppropriateMaxBondDim(
        maxBondDim,
        [node.edges[e] for e in leftEdges],
        [node.edges[e] for e in rightEdges])
    # Only the bond edge on the side where S is absorbed gets the name.
    if dir == '>>':
        leftEdgeName = edgeName
        rightEdgeName = None
    else:
        leftEdgeName = None
        rightEdgeName = edgeName
    try:
        [U, S, V, truncErr] = tn.split_node_full_svd(
            node,
            [node.edges[e] for e in leftEdges],
            [node.edges[e] for e in rightEdges],
            max_singular_values=maxBondDim,
            left_name=leftName,
            right_name=rightName,
            left_edge_name=leftEdgeName,
            right_edge_name=rightEdgeName)
    except np.linalg.LinAlgError:
        # SVD occasionally fails to converge on near-degenerate tensors;
        # rounding to 16 decimals perturbs the input enough to retry.
        node.tensor = np.round(node.tensor, 16)
        [U, S, V, truncErr] = tn.split_node_full_svd(
            node,
            [node.edges[e] for e in leftEdges],
            [node.edges[e] for e in rightEdges],
            max_singular_values=maxBondDim,
            left_name=leftName,
            right_name=rightName,
            left_edge_name=leftEdgeName,
            right_edge_name=rightEdgeName)
    # Replace the diagonal-matrix S node with a vector node of the
    # singular values, so it can be contracted with isDiag=True below.
    s = S
    S = tn.Node(np.diag(S.tensor))
    tn.remove_node(s)
    norm = np.sqrt(sum(S.tensor**2))
    # NOTE(review): if norm == 0 the maxTrunc / normalize branches below
    # divide by zero — presumably callers never hit this; confirm.
    if maxTrunc > 0:
        # Keep only singular values that are non-zero at maxTrunc digits
        # of the normalized spectrum; shrink U, S, V accordingly.
        meaningful = sum(np.round(S.tensor / norm, maxTrunc) > 0)
        S.tensor = S.tensor[:meaningful]
        U.tensor = np.transpose(np.transpose(U.tensor)[:meaningful])
        V.tensor = V.tensor[:meaningful]
    if normalize:
        S = multNode(S, 1 / norm)
    for e in S.edges:
        e.name = edgeName
    if dir == '>>':
        l = copyState([U])[0]
        r = multiContraction(S, V, '1', '0',
                             cleanOr1=True, cleanOr2=True, isDiag1=True)
    elif dir == '<<':
        l = multiContraction(U, S, [len(U.edges) - 1], '0',
                             cleanOr1=True, cleanOr2=True, isDiag2=True)
        r = copyState([V])[0]
    elif dir == '>*<':
        # Detach U and V from the network via copies so the originals
        # (and the rest of this network) can be cleaned up by the caller.
        v = V
        V = copyState([V])[0]
        tn.remove_node(v)
        u = U
        U = copyState([U])[0]
        tn.remove_node(u)
        return [U, S, V, truncErr]
    else:
        # Previously an invalid dir surfaced as a NameError on l/r below.
        raise ValueError(
            "dir must be one of '>>', '<<', '>*<', got '{}'".format(dir))
    tn.remove_node(U)
    tn.remove_node(S)
    tn.remove_node(V)
    return [l, r, truncErr]