Beispiel #1
0
def _get_mps(hgen_l, hgen_r, phi, direction, labels):
    '''
    Combine hgen_l and hgen_r with the two-site tensor phi to get the matrix product state.

    Parameters:
        :hgen_l, hgen_r: left / right block generators. ``N`` is read as the number
            of sites of each block, and ``evolutor`` supplies the stored site tensors
            (``evolutor.A(i, dense=True)`` and ``evolutor.get_AL(dense=True)``).
        :phi: 4-index array, labelled here as (al, sl+1, al+2, sl+2) with l = NL-1
            (presumably the two-site wave function between the blocks).
        :direction: '->' absorbs the last left site tensor into phi and splits off a
            new right tensor B. Any other value absorbs the last right site tensor
            and splits off a new left tensor A.
        :labels: passed through unchanged to ``MPS``.

    Return:
        <MPS>, built from the left tensors AL, the right tensors BL and the singular
        values S of the SVD of phi.
    '''
    NL, NR = hgen_l.N, hgen_r.N
    phi = tensor.Tensor(phi, labels=['al', 'sl+1', 'al+2', 'sl+2'])  # l = NL-1
    if direction == '->':
        # A[sNL](NL-1, NL): last site tensor of the left block, indices (sl+1, al, al+1').
        A = hgen_l.evolutor.A(NL - 1, dense=True)
        A = tensor.Tensor(A, labels=['sl+1', 'al', 'al+1\''])
        # Absorb A into phi (shared labels al, sl+1), then reorder to phi(al+1', sl+2, al+2).
        phi = tensor.contract([A, phi])
        phi = phi.chorder([0, 2, 1])
        # Decouple phi into U*S*V. V has orthonormal rows and becomes the new B.
        U, S, V = svd(phi.reshape([phi.shape[0], -1]), full_matrices=False)
        U = tensor.Tensor(U, labels=['al+1\'', 'al+1'])
        # Fold U into A (contracting al+1'), giving A(sl+1, al, al+1).
        A = A * U
        # V(al+1, sl+2, al+2) -> B(sl+2, al+2, al+1), the storage order used for B.
        B = transpose(
            V.reshape([S.shape[0], phi.shape[1], phi.shape[2]]),
            axes=(1, 2, 0)
        )
    else:
        # B[sNR](NL+1, NL+2): last site tensor of the right block, indices (sl+2, al+2, al+1').
        B = hgen_r.evolutor.A(NR - 1, dense=True)
        # NOTE(review): the original author questioned whether this conjugate is
        # correct ("#!the conjugate?") -- confirm against the evolutor's storage convention.
        B = tensor.Tensor(B, labels=['sl+2', 'al+2', 'al+1\'']).conj()
        # Absorb B into phi (shared labels al+2, sl+2), giving phi(al, sl+1, al+1').
        phi = tensor.contract([phi, B])
        # Decouple phi into U*S*V. U has orthonormal columns and becomes the new A.
        U, S, V = svd(phi.reshape([phi.shape[0] * phi.shape[1], -1]),
                      full_matrices=False)
        V = tensor.Tensor(V, labels=['al+1', 'al+1\''])
        # Fold V into B, reorder (al+1, sl+2, al+2) -> (sl+2, al+2, al+1) and
        # conjugate back, because B is kept in transposed order by default.
        B = (V * B).chorder([1, 2, 0]).conj()
        # U(al, sl+1, al+1) -> A(sl+1, al, al+1), the storage order used for A.
        A = transpose(
            U.reshape([phi.shape[0], phi.shape[1], S.shape[0]]),
            axes=(1, 0, 2)
        )

    # Left tensors: stored left-block tensors with the last one replaced by the new A.
    AL = hgen_l.evolutor.get_AL(dense=True)[:-1] + [A]
    # Right tensors: the new B, then the stored right-block tensors in reverse
    # order, skipping the last stored one (which B replaces).
    BL = [B] + hgen_r.evolutor.get_AL(dense=True)[::-1][1:]

    # Swap the first two axes of every tensor. Right tensors are also conjugated.
    AL = [transpose(ai, axes=(1, 0, 2)) for ai in AL]
    BL = [transpose(bi, axes=(1, 0, 2)).conj() for bi in BL]
    # Order 0..NL-1 for the left tensors, then NL+NR-1 down to NL for the right ones.
    # Built as explicit lists: on Python 3, range objects cannot be joined with '+'.
    forder = list(range(NL)) + list(range(NL, NL + NR))[::-1]
    mps = MPS(AL=AL,
              BL=BL,
              S=S,
              labels=labels,
              forder=forder)
    return mps