コード例 #1
0
ファイル: basicOperations.py プロジェクト: ZENG-Hui/DMRG_py
def contractDiag(node: tn.Node, diag: np.array, edgeNum: int):
    """Multiply `node` along axis `edgeNum` by the diagonal vector `diag`.

    This is the contraction of edge `edgeNum` with the diagonal matrix
    ``np.diag(diag)``: the slice at position ``i`` of axis `edgeNum` is
    scaled by ``diag[i]``. The edges and axis order of `node` are unchanged;
    only ``node.tensor`` is replaced.

    Args:
        node: Node whose tensor is scaled.
        diag: 1-D array whose length equals the size of axis `edgeNum`
            (or is 1, in which case it broadcasts).
        edgeNum: Index of the axis/edge to scale.

    Returns:
        The same `node` object, holding a new tensor.
    """
    diag = np.asarray(diag)
    # Reshape to (1, ..., len(diag), ..., 1) so `diag` broadcasts along axis
    # `edgeNum` only. This replaces the old transpose / per-slice loop /
    # transpose, whose loop bound was read through edge 0 instead of the
    # scaled axis. Building a new array also avoids writing in place through a
    # transpose view, which silently mutated the caller's original array.
    # It also lets integer tensors be promoted instead of failing the
    # in-place cast.
    broadcast_shape = [1] * len(node.edges)
    broadcast_shape[edgeNum] = -1
    node.tensor = node.tensor * diag.reshape(broadcast_shape)
    return node
コード例 #2
0
def pairwise_reduction(net: BatchTensorNetwork, node: tensornetwork.Node,
                       edge: tensornetwork.Edge) -> tensornetwork.Node:
    """Parallel contraction of matrix chains.

    The operation performed by this function is described in Fig. 4 of the
    paper `TensorNetwork for Machine Learning`. It leads to a more efficient
    implementation of the MPS classifier both in terms of predictions and
    automatic gradient calculation. The idea is that the whole MPS side is
    saved in memory as one node that carries an artificial "space" edge. This
    function removes this additional index by performing the pairwise
    contractions as shown in the Figure.

    Args:
      net: TensorNetwork that contains the node we want to reduce.
      node: Node to reduce pairwise. The corresponding tensor should have the
        form (..., space edge, ..., a, b) and matrix multiplications will be
        performed over the last two indices using matmul.
      edge: Space edge of the node.

    Returns:
      node: The same node object, modified in place: `edge` is removed from
        its edge list and its tensor has the shape of the given tensor with
        the `edge` axis removed.

    Raises:
      ValueError: If `edge` is not dangling, does not belong to `node`, or
        has dimension 0.
      NotImplementedError: If an edge located after `edge` on `node` is a
        trace edge. Raised before `node` is modified.
    """
    # NOTE: This method could be included in the BatchedTensorNetwork class
    # however it seems better to be separated because (at least with the current
    # implementation) it performs a very specialized/non-general operation.
    # It also uses tf.matmul which restricts the backend, however this can be
    # easily generalized since all the backends support batched matmul.
    if not edge.is_dangling():
        raise ValueError("Cannot reduce non-dangling edge '{}'".format(edge))
    if edge.node1 is not node:
        raise ValueError("Edge '{}' does not belong to node '{}'".format(
            edge, node))

    axis = edge.axis1
    tensor = node.tensor
    size = int(tensor.shape[axis])
    if size == 0:
        raise ValueError(
            "Cannot reduce edge '{}' with dimension 0".format(edge))

    # Every edge after the reduced axis will have its axis index shifted down
    # by one. Validate them all up front so an unsupported trace edge is
    # reported before `node` has been modified.
    shifted_edges = node.edges[axis + 1:]
    for e in shifted_edges:
        if e.node1 is e.node2:
            raise NotImplementedError("Cannot binary reduce node "
                                      "'{}' with trace edge '{}'".format(
                                          node, e))

    # Move the reduction axis to the front while keeping the remaining axes
    # in their original relative order, so that after the reduction they
    # still line up with `node.edges`. (Swapping it with axis 0 would put the
    # old axis 0 in the wrong place whenever `axis >= 2`.)
    edge_order = [axis] + [i for i in range(len(tensor.shape)) if i != axis]
    tensor = net.backend.transpose(tensor, edge_order)

    # Remove the reduced edge from the node and renumber the edges after it.
    node.edges.pop(axis)
    for e in shifted_edges:
        if e.node1 is node:
            e.axis1 -= 1
        else:
            e.axis2 -= 1

    # Idea from this implementation is from jemisjoky/TorchMPS
    # Multiply neighbouring pairs along the (now leading) space axis in one
    # batched matmul, carrying an odd leftover to the next round, until a
    # single matrix per batch entry remains.
    while size > 1:
        half_size = size // 2
        nice_size = 2 * half_size
        leftover = tensor[nice_size:]
        tensor = tf.matmul(tensor[0:nice_size:2], tensor[1:nice_size:2])
        tensor = net.backend.concat([tensor, leftover], axis=0)
        size = half_size + int(size % 2 == 1)

    node.tensor = tensor[0]
    return node
コード例 #3
0
def svdTruncation(node: tn.Node,
                  leftEdges: List[int],
                  rightEdges: List[int],
                  dir: str,
                  maxBondDim=128,
                  leftName='U',
                  rightName='V',
                  edgeName='default',
                  normalize=False,
                  maxTrunc=0):
    """Split `node` with a truncated SVD and optionally absorb the spectrum.

    Args:
        node: Node to split.
        leftEdges: Axis indices of `node` that go to the left factor U.
        rightEdges: Axis indices of `node` that go to the right factor V.
        dir: Where the singular values S end up:
            ``'>>'``  - S is contracted into V (``multiContraction(S, V)``);
            ``'<<'``  - S is contracted into U (``multiContraction(U, S)``);
            ``'>*<'`` - U, S and V are returned separately.
        maxBondDim: Requested cap on the number of kept singular values; it is
            first adjusted by ``getAppropriateMaxBondDim``.
        leftName: Name given to the U node.
        rightName: Name given to the V node.
        edgeName: Name given to the new bond edge(s).
        normalize: If true, divide S by its 2-norm (skipped when the norm is
            zero, to avoid producing inf/nan).
        maxTrunc: If positive, additionally drop singular values whose ratio
            to the 2-norm rounds to zero at `maxTrunc` decimal places. At least
            one singular value is always kept.

    Returns:
        ``[l, r, truncErr]`` for ``'>>'`` and ``'<<'``, or
        ``[U, S, V, truncErr]`` for ``'>*<'``. ``truncErr`` is exactly what
        ``tn.split_node_full_svd`` returned; values removed by `maxTrunc`
        are not added to it.

    Raises:
        ValueError: If `dir` is not ``'>>'``, ``'<<'`` or ``'>*<'``. Checked
            before `node` is split.
    """
    if dir not in ('>>', '<<', '>*<'):
        raise ValueError(
            "Unknown SVD direction '{}'; expected '>>', '<<' or '>*<'".format(
                dir))

    maxBondDim = getAppropriateMaxBondDim(maxBondDim,
                                          [node.edges[e] for e in leftEdges],
                                          [node.edges[e] for e in rightEdges])
    # The new bond edge is named on the side that will not absorb S.
    if dir == '>>':
        leftEdgeName = edgeName
        rightEdgeName = None
    else:
        leftEdgeName = None
        rightEdgeName = edgeName

    def split():
        # One attempt at the full SVD split of `node`.
        return tn.split_node_full_svd(node,
                                      [node.edges[e] for e in leftEdges],
                                      [node.edges[e] for e in rightEdges],
                                      max_singular_values=maxBondDim,
                                      left_name=leftName,
                                      right_name=rightName,
                                      left_edge_name=leftEdgeName,
                                      right_edge_name=rightEdgeName)

    try:
        U, S, V, truncErr = split()
    except np.linalg.LinAlgError:
        # TODO: better recovery. For now, retry once after rounding the tensor
        # to 16 decimals (presumably to help a non-converging SVD).
        node.tensor = np.round(node.tensor, 16)
        U, S, V, truncErr = split()

    # Replace the diagonal-matrix S node by a node holding only the 1-D
    # vector of singular values.
    matrixS = S
    S = tn.Node(np.diag(S.tensor))
    tn.remove_node(matrixS)

    norm = np.linalg.norm(S.tensor)
    if maxTrunc > 0:
        if norm > 0:
            kept = np.count_nonzero(np.round(S.tensor / norm, maxTrunc) > 0)
        else:
            kept = 0
        # Never truncate the bond away entirely.
        kept = max(int(kept), 1)
        S.tensor = S.tensor[:kept]
        # U carries the bond as its last axis, V as its first.
        U.tensor = U.tensor[..., :kept]
        V.tensor = V.tensor[:kept]
    if normalize and norm > 0:
        S = multNode(S, 1 / norm)
    for e in S.edges:
        e.name = edgeName

    if dir == '>*<':
        # Hand back fresh copies of U and V and drop the originals.
        v = V
        V = copyState([V])[0]
        tn.remove_node(v)
        u = U
        U = copyState([U])[0]
        tn.remove_node(u)
        return [U, S, V, truncErr]

    if dir == '>>':
        l = copyState([U])[0]
        r = multiContraction(S,
                             V,
                             '1',
                             '0',
                             cleanOr1=True,
                             cleanOr2=True,
                             isDiag1=True)
    else:  # dir == '<<'
        l = multiContraction(U,
                             S, [len(U.edges) - 1],
                             '0',
                             cleanOr1=True,
                             cleanOr2=True,
                             isDiag2=True)
        r = copyState([V])[0]

    tn.remove_node(U)
    tn.remove_node(S)
    tn.remove_node(V)
    return [l, r, truncErr]