Example #1
0
def psim(state, N=None, DIM=None, fixed=None, visualize=False, order=2, state0=None, grid=None):
    """
    Point-matching similarity measure and its gradient.

    Decodes ``state`` with ``tj.state_to_weinstein_darboux`` and compares the
    moving values against ``fixed`` with a weighted sum of squared
    differences, one term per included order.

    Parameters
    ----------
    state : array
        Flattened state passed to ``tj.state_to_weinstein_darboux``.
    N, DIM : int, optional
        Forwarded to ``tj.state_to_weinstein_darboux``. ``DIM`` is also used
        to reshape the integrated grid when drawing it.
    fixed : sequence
        Target values. ``fixed[0]`` is compared with the positions (2-index
        arrays, see the ``"ia"`` einsum). ``fixed[1]`` (3-index) is used when
        ``order >= 1`` and ``fixed[2]`` (4-index) when ``order >= 2``.
    visualize : bool
        If true, draw fixed (blue) and moving (red) points, momenta arrows and
        Jacobian-deformed circles in matplotlib figure 1.
    order : int
        Highest order term included; 0, 1 or 2 (anything above 1 is treated
        as 2).
    state0, grid : optional
        Initial state and a ``(reggrid, Nx, Ny)`` tuple. Both are needed for
        the deformed grid to be drawn; they are only used when visualizing.

    Returns
    -------
    tuple
        ``(value, gradients)``. ``gradients`` is ``(dq0,)``, ``(dq0, dq1)``
        or ``(dq0, dq1, dq2)`` depending on ``order``.
    """
    qm, qm_1, qm_2, pm, mum_1, mum_2 = tj.state_to_weinstein_darboux(state, N, DIM)
    qf = fixed[0]
    if order >= 1:
        qf_1 = fixed[1]
    if order >= 2:
        qf_2 = fixed[2]

    # weighting between different order terms
    w = [1, 0.5, 0.2]

    # value: weighted squared differences (TODO: normalize by 1./N?)
    v0 = qm - qf
    m0 = w[0] * np.einsum("ia,ia", v0, v0)
    if order >= 1:
        v1 = qm_1 - qf_1
        m1 = w[1] * np.einsum("iab,iab", v1, v1)
    if order >= 2:
        v2 = qm_2 - qf_2
        m2 = w[2] * np.einsum("iabg,iabg", v2, v2)

    # gradient of each term with respect to the moving values
    dq0 = w[0] * 2.0 * v0
    if order >= 1:
        dq1 = w[1] * 2.0 * v1
    if order >= 2:
        dq2 = w[2] * 2.0 * v2

    ## visualization
    if visualize:
        plt.figure(1)
        plt.clf()
        plt.plot(qf[:, 0], qf[:, 1], "bo")
        plt.plot(qm[:, 0], qm[:, 1], "rx")

        # grid. Use ``is not None``: ``array != None`` compares elementwise,
        # and using that array as a condition raises ValueError.
        if state0 is not None and grid is not None:
            (reggrid, Nx, Ny) = grid
            (_, _, mgridts) = tj.integrate(state0, pts=reggrid)
            mgridT = mgridts[-1:].reshape(-1, DIM)
            pg.plotGrid(mgridT, Nx, Ny)

        # generate vertices of a circle of radius 0.2. The extra last column
        # stays at the origin, so each drawn outline ends with a segment from
        # the circle back to the point itself.
        N_vert = 20
        circle_verts = np.zeros([2, N_vert + 1])
        theta = np.linspace(0, 2 * np.pi, N_vert)
        circle_verts[0, 0:N_vert] = 0.2 * np.cos(theta)
        circle_verts[1, 0:N_vert] = 0.2 * np.sin(theta)
        units = np.ones(N_vert + 1)

        for i in range(len(qm)):
            plt.arrow(
                qm[i, 0], qm[i, 1], 0.2 * pm[i, 0], 0.2 * pm[i, 1], head_width=0.2, head_length=0.2, fc="b", ec="b"
            )
            if qm_1 is not None:
                # circle mapped by the point's Jacobian, translated to the point
                verts = np.dot(qm_1[i, :, :], circle_verts) + np.outer(qm[i, :], units)
                plt.plot(verts[0], verts[1], "r-")

        border = 0.4
        allpts = np.vstack((qf, qm))
        plt.xlim(min(allpts[:, 0]) - border, max(allpts[:, 0]) + border)
        plt.ylim(min(allpts[:, 1]) - border, max(allpts[:, 1]) + border)
        plt.axis("equal")
        plt.draw()

    if order == 0:
        return (m0, (dq0,))
    elif order == 1:
        return (m0 + m1, (dq0, dq1))
    else:
        return (m0 + m1 + m2, (dq0, dq1, dq2))
Example #2
0
def imsim(state, N=None, imshape=None, DIM=None, h=None, imms=None, Dimms=None, imf=None, imm=None, simfs=None, sDimfs=None, sgrid=None, visualize=False, state0=None, grid=None, order=None, imgrid=None, hscaling=None, SIGMA=None, imfs=None):
    """
    Image similarity measure and its gradient with respect to the jet state.

    The moving image ``imms`` is sampled at the current points ``q`` and
    compared with the precomputed fixed-image samples ``simfs``. For higher
    orders it is also compared with the fixed derivative samples ``sDimfs``.

    * order 0: ``m0 = h**d * sum(v0**2)`` with ``v0 = simfs - simms``
    * order >= 1: adds ``m1 = h**(d+2)/12 * sum(v1**2)``, where ``v1``
      compares the fixed first derivatives with the moving ones transported
      by the Jacobians ``q_1``
    * order >= 2: adds ``m2``, built from the second-order mismatch ``G``

    Parameters
    ----------
    state : array
        Flattened state decoded with ``tj.state_to_weinstein_darboux``.
    N, DIM : int
        Number of points and spatial dimension.
    h : number
        Sampling scale entering the quadrature weights ``h**d``,
        ``h**(d+2)`` and ``h**(d+4)``.
    imms, simfs, sDimfs :
        Moving image data, and fixed-image samples and derivative samples.
    hscaling :
        Forwarded to ``sample``; also multiplies the ``dq0`` gradient.
    order : int
        0, 1 or 2. The default ``None`` cannot be compared with ints on
        Python 3, so callers presumably always pass it.
    visualize : bool
        If true, draw diagnostic figures and save every open figure as
        ``output/<pid>/figure<i>.eps``.
    imshape, sgrid, state0, grid, imgrid, imf, imm, imfs :
        Only used for visualization.
    Dimms, SIGMA :
        Unused in this function.

    Returns
    -------
    tuple
        ``(value, gradients)``. ``gradients`` is ``(dq0,)``, ``(dq0, dq1)``
        or ``(dq0, dq1, dq2)`` depending on ``order``.
    """
    q,q_1,q_2,p,mu_1,mu_2 = tj.state_to_weinstein_darboux( state,N,DIM )

    sampleq = partial(sample,d2unzip(q,N), hscaling=hscaling)
    simms = sampleq(imms)
    # NOTE(review): ``derivs`` is neither a parameter nor a local here, so it
    # must be provided at module scope or this raises NameError. An earlier
    # variant used ``Dimms`` -- confirm intent.
    sDimms = [apply_2d_slices(partial(sampleq, imms), derivs[i]) for i in range(len(derivs))]

    d = DIM
    delta = np.identity(DIM)
    one_minus_delta = np.ones([DIM,DIM])-np.eye(DIM)

    # value
    v0 = simfs-simms
    m0 = (h**d)*np.einsum('i,i',v0,v0)
    if order >= 1:
        v1 = sDimfs[0]-np.einsum('bi,iba->ai',sDimms[0],q_1)
        m1 = (h**(d+2))/12*np.einsum('ai,ai',v1,v1)
    if order >= 2:
        G = sDimfs[1] \
                -np.einsum('dci,idb,ica->abi',sDimms[1],q_1,q_1) \
                -np.einsum('ci,icab->abi',sDimms[0],q_2)
        m2 = (h**(d+2))/12*np.einsum('i,aai->',v0,G) \
                + (h**(d+4))/(5*2**6)*np.einsum('aai,aai->',G,G) \
                + (h**(d+4))/(9*2**6)*np.einsum('ab,aai,bbi->',one_minus_delta,G,G) \
                + (h**(d+4))/(9*2**6)*np.einsum('ab,abi,abi->',one_minus_delta,G,G)

    # debug output
    if order >= 0:
        logging.info("m0: " + str(m0))
    if order >= 1:
        logging.info("m1: " + str(m1))
    if order >= 2:
        logging.info("m2: " + str(m2))

    # gradient
    # dq0
    g00 = -2*(h**d)*np.einsum('i,ai->ia',v0,sDimms[0])
    dq0 = g00
    if order >= 1:
        g01 = -(h**(d+2))/6*np.einsum('bai,ibe,ec,ci->ia',sDimms[1],q_1,delta,v1)
        dq0 = dq0+g01
    if order >= 2:
        g02 = -(h**(d+2))/12*np.einsum('ai,ddi->ia',sDimms[0],G)
        G1 = -np.einsum('bcai,ibe,icd->deai',sDimms[2],q_1,q_1) \
                -np.einsum('cai,icde->deai',sDimms[1],q_2)
        g03 = (h**(d+2))/12*np.einsum('i,ddai->ia',v0,G1)
        g04 = (h**(d+4))/(5*2**5)*np.einsum('ddi,ddai->ia',G,G1) \
                +(h**(d+4))/(9*2**5)*np.einsum('de,ddi,eeai->ia',one_minus_delta,G,G1) \
                +(h**(d+4))/(9*2**5)*np.einsum('de,dei,deai->ia',one_minus_delta,G,G1)
        dq0 = dq0+g02+g03+g04

    # rescale
    dq0 = hscaling*dq0

    # dq1
    if order >= 1:
        g10 = -(h**(d+2))/6*np.einsum('ai,bi->iab',sDimms[0],v1)
        dq1 = g10
        if order >= 2:
            G2 = -np.einsum('aci,ice,db->deabi',sDimms[1],q_1,delta) \
                    -np.einsum('aci,icd,eb->deabi',sDimms[1],q_1,delta)
            g11 = (h**(d+2))/12*np.einsum('i,ddabi->iab',v0,G2)
            g12 = (h**(d+4))/(5*2**5)*np.einsum('ddi,ddabi->iab',G,G2) \
                    +(h**(d+4))/(9*2**5)*np.einsum('de,ddi,eeabi->iab',one_minus_delta,G,G2) \
                    +(h**(d+4))/(9*2**5)*np.einsum('de,dei,deabi->iab',one_minus_delta,G,G2)
            dq1 = dq1+g11+g12

    # dq2
    if order >= 2:
        G3 = -np.einsum('bd,ce,ai->deabci',delta,delta,sDimms[0])
        g21 = (h**(d+2))/12*np.einsum('i,bc,bcabci->iabc',v0,delta,G3)
        g22 = (h**(d+4))/(5*2**5)*np.einsum('ddi,ddabci->iabc',G,G3) \
                +(h**(d+4))/(9*2**5)*np.einsum('de,ddi,eeabci->iabc',one_minus_delta,G,G3) \
                +(h**(d+4))/(9*2**5)*np.einsum('de,dei,deabci->iabc',one_minus_delta,G,G3)
        dq2 = g21+g22

    # visualization
    if visualize:
        logging.info("iteration visualization output")

        def _square(vals):
            # reshape a flat sample vector taken on a square grid into an
            # image; reshape requires integer sizes (sqrt returns a float)
            side = int(round(sqrt(vals.shape[0])))
            return vals.reshape(side, side)

        x = np.arange(imshape[0])
        y = np.arange(imshape[1])

        plt.figure(1)
        plt.clf()
        scimms = samplecross((x,y),imms).reshape(imshape)
        cmin = np.min([np.min(simms),np.min(simfs),np.min(scimms)])
        cmax = np.max([np.max(simms),np.max(simfs),np.max(scimms)])
        plt.imshow(scimms.T,vmin=cmin,vmax=cmax)
        plt.plot(d2zip(sgrid)[:,0],d2zip(sgrid)[:,1],'bo')
        plt.plot(q[:,0],q[:,1],'rx')
        plt.gray()
        plt.colorbar()
        plotJacobians(q,q_1)
        plt.xlim(0,imshape[0])
        plt.ylim(0,imshape[1])

        plt.figure(2)
        plt.clf()
        plt.imshow(rse(simfs,N).T,vmin=cmin,vmax=cmax)
        plt.colorbar()
        plt.figure(3)
        plt.clf()
        plt.imshow(rse(simms,N).T,vmin=cmin,vmax=cmax)
        plt.colorbar()

        plt.figure(4)
        plt.clf()
        plt.imshow(rse(v0,N).T)
        plt.colorbar()

        plt.figure(5)
        plt.clf()
        plt.imshow(rse(sDimms[0][0,:],N).T)
        plt.colorbar()

        # grid plot. ``is not None``: comparing an array with ``!= None`` is
        # elementwise, and using the result as a condition raises ValueError.
        if sgrid is not None:
            qf = d2zip(sgrid)

            plt.figure(6)
            plt.clf()
            plt.plot(qf[:,0],qf[:,1],'bo')
            plt.plot(q[:,0],q[:,1],'rx')

            # grid
            if state0 is not None and grid is not None:
                (reggrid,Nx,Ny) = grid
                (_,_,mgridts) = tj.integrate(state0,pts=reggrid)
                mgridT = mgridts[-1:].reshape(-1,DIM)
                pg.plotGrid(mgridT,Nx,Ny)

            border = 0.4
            allpts = np.vstack((qf,q))
            plt.xlim(min(allpts[:,0])-border,max(allpts[:,0])+border)
            plt.ylim(min(allpts[:,1])-border,max(allpts[:,1])+border)
            plt.axis('equal')

        # warped images
        if state0 is not None and imgrid is not None and imf is not None and imm is not None:
            # fixed image, interpolated
            plt.figure(20)
            plt.clf()
            simf = sample(d2unzip(imgrid),imf,hscaling=hscaling)
            plt.imshow(_square(simf).T)
            plt.colorbar()
            # fixed image, interpolated
            plt.figure(21)
            plt.clf()
            simf = sample(d2unzip(imgrid),imfs,hscaling=hscaling)
            plt.imshow(_square(simf).T)
            plt.colorbar()

            # moving image, interpolated without transformation
            plt.figure(22)
            plt.clf()
            simf = sample(d2unzip(imgrid),imm,hscaling=hscaling)
            plt.imshow(_square(simf).T)
            plt.colorbar()
            # moving image, interpolated without transformation
            plt.figure(23)
            plt.clf()
            simf = sample(d2unzip(imgrid),imms,hscaling=hscaling)
            plt.imshow(_square(simf).T)
            plt.colorbar()

            # moving image, interpolated
            plt.figure(24)
            plt.clf()
            (_,_,mimgridts) = tj.integrate(state0,pts=imgrid)
            mimgridT = mimgridts[-1:].reshape(-1,DIM)
            simm = sample(d2unzip(mimgridT),imm,hscaling=hscaling)
            plt.imshow(_square(simm).T)
            plt.colorbar()
            # moving image, interpolated
            plt.figure(25)
            plt.clf()
            simm = sample(d2unzip(mimgridT),imms,hscaling=hscaling)
            plt.imshow(_square(simm).T)
            plt.colorbar()

        plt.draw()

        # save figures; create the per-process output directory once
        try:
            os.makedirs('output/%s' % os.getpid())
        except OSError:
            pass  # already exists; other failures surface in savefig below
        for i in plt.get_fignums():
            plt.figure(i)
            plt.savefig('output/%s/figure%d.eps' % (os.getpid(),i) )

    if order == 0:
        return (m0, (dq0, ))
    elif order == 1:
        return (m0+m1, (dq0,dq1))
    else:
        return (m0+m1+m2, (dq0,dq1,dq2))
Example #3
0
def F(sim, nonmoving, x, weights=None, adjint=True, order=2, scalegrad=None, simGradCheck=False, energyGradCheck=False, visualize=False):
    """
    Function that scipy's optimize function will call for returning the value
    and gradient for a given x. The forward and adjoint integration is called
    from this function using the values supplied by the similarity measure.

    Parameters
    ----------
    sim : dict
        Similarity measure description. Reads ``'N'``, ``'DIM'`` and the
        callable ``'f'``, which returns ``(value, gradients)``.
    nonmoving, x : arrays
        Fixed and optimized parts of the (triangular-form) initial state.
    weights : sequence
        ``weights[0]`` multiplies the Hamiltonian energy and ``weights[1]``
        the match term.
    adjint : bool
        If false, return only the objective value (no adjoint integration).
    order : int
        Highest momentum order present in ``x``. Missing mu_1 / mu_2 blocks
        are padded with zeros.
    scalegrad :
        If truthy, mu_1 (order >= 1) and mu_2 (order == 2) are multiplied by
        ``tj.SIGMA`` before the flow, and the matching gradient entries are
        scaled the same way.
    simGradCheck, energyGradCheck : bool
        Log finite-difference checks of the similarity / energy gradients.
    visualize : bool
        Forwarded to ``sim['f']``.

    Returns
    -------
    float or tuple
        The objective ``weights[1]*match + weights[0]*energy``. When
        ``adjint`` is true, returns ``(objective, grad0)`` with ``grad0`` in
        triangular form matching ``x``.
    """
    N = sim['N']
    DIM = sim['DIM']

    q = np.reshape( nonmoving[0:N*DIM] , [N,DIM] )
    tj.gaussian.N = N
    tj.gaussian.DIM = DIM
    tj.gaussian.SIGMA = tj.SIGMA
    K,DK,D2K,D3K,D4K,D5K,D6K = tj.derivatives_of_kernel(q,q)

    # input
    state0 = np.append(nonmoving, x)
    if order < 1:
        state0 = np.append(state0, np.zeros(N*DIM**2)) # append mu_1
    if order < 2:
        state0 = np.append(state0, np.zeros(N*DIM*tj.triuDim())) # append mu_2
    # shift from triangular to symmetric
    state0 = tj.triangular_to_state(state0)
    triunonmoving = nonmoving
    triux = x
    # floor division: slice bounds must be ints (``/`` yields a float on Python 3)
    half = state0.size // 2
    nonmoving = state0[0:half]
    x = state0[half:]

    # rescale
    if scalegrad:
        q0,q0_1,q0_2,p0,mu0_1,mu0_2 = tj.state_to_weinstein_darboux( state0 )
        if order >= 1:
            mu0_1 = tj.SIGMA*mu0_1
        if order == 2:
            mu0_2 = tj.SIGMA*mu0_2
        state0 = tj.weinstein_darboux_to_state(q0,q0_1,q0_2,p0,mu0_1,mu0_2)

    q0,q0_1,q0_2,p0,mu0_1,mu0_2 = tj.state_to_weinstein_darboux( state0 )

    # flow
    (_, y_span) = tj.integrate(state0)
    stateT = y_span[-1]

    # debug
    logging.info("||p0||: " + str(np.linalg.norm(p0)))
    logging.info("||mu0_1||: " + str(np.linalg.norm(mu0_1)))
    logging.info("||mu0_2||: " + str(np.linalg.norm(mu0_2)))

    simT = sim['f'](stateT, state0=state0, visualize=visualize)

    # debug
    logging.info('match term after flow: ' + str(simT[0]))

    Ediff = tj.Hamiltonian(q0,p0,mu0_1,mu0_2) # path energy from Hamiltonian
    logging.info('Hamiltonian: ' + str(Ediff))

    if not adjint:
        return weights[1]*simT[0]+weights[0]*Ediff

    dq = simT[1][0]
    if order >= 1:
        dq_1 = simT[1][1]
    else:
        dq_1 = np.zeros(q0_1.shape)
    if order >= 2:
        dq_2 = simT[1][2]
    else:
        dq_2 = np.zeros(q0_2.shape)

    logging.info("||dq||: " + str(np.linalg.norm(dq)))
    logging.info("||dq_1||: " + str(np.linalg.norm(dq_1)))
    logging.info("||dq_2||: " + str(np.linalg.norm(dq_2)))
    # terminal condition for the adjoint: similarity gradient in the point
    # slots, zeros in the momentum slots
    ds1 = tj.weinstein_darboux_to_state(dq,dq_1,dq_2,np.zeros(dq.shape),np.zeros(dq_1.shape),np.zeros(dq_2.shape),N,DIM)

    if simGradCheck:
        logging.info("computing finite difference approximation of sim gradient")
        fsim = lambda x: sim['f'](np.hstack( (x,stateT[x.size:],) ), state0=state0)[0]
        findiffgrad = approx_fprime(stateT[0:N*DIM+N*DIM**2+N*DIM**3],fsim,1e-5)
        compgrad = ds1[0:N*DIM+N*DIM**2+N*DIM**3]
        graderr = np.max(abs(findiffgrad-compgrad))
        logging.debug("sim gradient numerical check error: %e",graderr)
        logging.debug("finite diff gradient: " + str(findiffgrad))
        logging.debug("computed gradient: " + str(compgrad))
        logging.debug("difference: " + str(findiffgrad-compgrad))
    if energyGradCheck:
        logging.info("computing finite difference approximation of energy gradient")
        fsim = lambda x: tj.Hamiltonian(q0,np.reshape(x[0:N*DIM],[N,DIM]),np.reshape(x[N*DIM:N*DIM+N*DIM**2],[N,DIM,DIM]),np.reshape(x[N*DIM+N*DIM**2:N*DIM+N*DIM**2+N*DIM**3],[N,DIM,DIM,DIM]))
        findiffgrad = approx_fprime(np.hstack((p0.flatten(),mu0_1.flatten(),mu0_2.flatten(),)),fsim,1e-7)
        compgrad = tj.grad_Hamiltonian(q0,p0,mu0_1,mu0_2)
        graderr = np.max(abs(findiffgrad-compgrad))
        logging.debug("energy gradient numerical check error: %e",graderr)
        logging.debug("finite diff gradient: " + str(findiffgrad))
        logging.debug("computed gradient: " + str(compgrad))
        logging.debug("difference: " + str(findiffgrad-compgrad))

    (_, y_span) = tj.adj_integrate(stateT,ds1)
    adjstate0 = y_span[-1]
    ahalf = adjstate0.size // 2

    assert(nonmoving.size+x.size<=ahalf)
    gradE = tj.grad_Hamiltonian(q0,p0,mu0_1,mu0_2)
    assert(ahalf-nonmoving.size == gradE.size) # gradE doesn't include point variations currently
    gradE = gradE[0:x.size]
    start = ahalf+nonmoving.size
    grad0 = weights[1]*adjstate0[start:start+x.size] + weights[0]*gradE # transported gradient + grad of energy

    # write the combined gradient back so it is converted to triangular form
    adjstate0[start:start+grad0.size] = grad0
    grad0 = tj.state_to_triangular(adjstate0[ahalf:adjstate0.size])[triunonmoving.size:triunonmoving.size+triux.size]

    grad0 = np.ndarray.flatten(grad0)

    # rescale (chain rule for the SIGMA scaling applied above)
    if scalegrad:
        if order >= 1:
            grad0[N*DIM:N*DIM+N*DIM**2] = tj.SIGMA*grad0[N*DIM:N*DIM+N*DIM**2]
        if order == 2:
            grad0[N*DIM+N*DIM**2:N*DIM+N*DIM**2+N*DIM**3] = tj.SIGMA*grad0[N*DIM+N*DIM**2:N*DIM+N*DIM**2+N*DIM**3]

    # diagnostics
    dq0,dq0_1,dq0_2,dp0,dmu0_1,dmu0_2 = tj.state_to_weinstein_darboux( adjstate0[ahalf:adjstate0.size],N,DIM )
    logging.info("||dp0|| final: " + str(np.linalg.norm(dp0)))
    logging.info("||dmu0_1|| final: " + str(np.linalg.norm(dmu0_1)))
    logging.info("||dmu0_2|| final: " + str(np.linalg.norm(dmu0_2)))

    return (weights[1]*simT[0]+weights[0]*Ediff, grad0)
Example #4
0
	for i in range(0,len(q)):
		plt.arrow(q[i,0], q[i,1], 0.2*p[i,0], 0.2*p[i,1],\
				  head_width=0.2, head_length=0.2,\
				  fc='b', ec='b')
                if (q1 != None):
                        verts = np.dot(q1[i,:,:], circle_verts ) \
                                + np.outer(q[i,:],units)
                        print np.shape( verts )
                        print np.shape( q1 )
                        plt.plot(verts[0],verts[1],'b-')

        plt.axis([- W, W,- W, W ])
	return f

# Render each stored state as a PNG movie frame in ./movie_frames/.
# print(...) with a single argument behaves the same on Python 2 and 3.
y_data = np.load('output/state_data.npy')
time_data = np.load('output/time_data.npy')

N_timestep = y_data.shape[0]
print('generating png files')
for k in range(N_timestep):
    q, q_1, q_2, p, mu_1, mu_2 = tj.state_to_weinstein_darboux(y_data[k])
    f = display_velocity_field(q, p, mu_1, mu_2, q_1)
    time_s = str(time_data[k])
    # title shows the first four characters of the time stamp
    plt.suptitle('t = ' + time_s[0:4], fontsize=16, x=0.75, y=0.25)
    fname = './movie_frames/frame_' + str(k) + '.png'
    f.savefig(fname)
    plt.close(f)
print('done')
Example #5
0
def Fdisp(sim, nonmoving, x, adjint=True, weights=None, order=2, scalegrad=None):
    """
    Testing variant of ``F`` that uses a linear displacement instead of
    integrating the flows.

    The end state is ``qT = q0 + p0``, ``qT_1 = q0_1 + mu0_1`` and
    ``qT_2 = q0_2 + mu0_2``. The gradient of the match term with respect to
    the momenta is therefore the similarity gradient itself.

    Parameters
    ----------
    sim : dict
        Reads ``'N'``, ``'DIM'`` and the callable ``'f'``, which returns
        ``(value, gradients)``.
    nonmoving, x : arrays
        Fixed and optimized parts of the (triangular-form) initial state.
    adjint : bool
        If false, return only the objective value.
    weights : sequence
        ``weights[0]`` multiplies the Hamiltonian energy and ``weights[1]``
        the match term.
    order : int
        Highest momentum order present in ``x``. Missing blocks are padded
        with zeros.
    scalegrad :
        If truthy, mu_1 / mu_2 are scaled by ``tj.SIGMA`` and the gradient
        entries accordingly.

    Returns
    -------
    float or tuple
        The objective, or ``(objective, grad0)`` when ``adjint`` is true.
        ``grad0`` is in triangular form matching ``x``.
    """

    # input
    N = sim['N']
    DIM = sim['DIM']
    tj.gaussian.N = N
    tj.gaussian.DIM = DIM
    tj.gaussian.SIGMA = tj.SIGMA
    state0 = np.append(nonmoving, x)
    if order < 1:
        state0 = np.append(state0, np.zeros(N*DIM**2)) # append mu_1
    if order < 2:
        # mu_2 is still in triangular form here (converted by
        # triangular_to_state below), so pad with the triangular size as F does
        state0 = np.append(state0, np.zeros(N*DIM*tj.triuDim())) # append mu_2

    # shift from triangular to symmetric
    state0 = tj.triangular_to_state(state0)
    triunonmoving = nonmoving
    triux = x
    # floor division: slice bounds must be ints (``/`` yields a float on Python 3)
    half = state0.size // 2
    nonmoving = state0[0:half]
    x = state0[half:]

    # rescale
    if scalegrad:
        q0,q0_1,q0_2,p0,mu0_1,mu0_2 = tj.state_to_weinstein_darboux( state0 )
        if order >= 1:
            mu0_1 = tj.SIGMA*mu0_1
        if order == 2:
            mu0_2 = tj.SIGMA*mu0_2
        state0 = tj.weinstein_darboux_to_state(q0,q0_1,q0_2,p0,mu0_1,mu0_2)

    q0,q0_1,q0_2,p0,mu0_1,mu0_2 = tj.state_to_weinstein_darboux( state0 )

    # displacement
    qT = q0+p0
    qT_1 = q0_1+mu0_1
    qT_2 = q0_2+mu0_2
    stateT = tj.weinstein_darboux_to_state(qT,qT_1,qT_2,p0,mu0_1,mu0_2)

    # debug
    logging.info("||p0||: " + str(np.linalg.norm(p0)))
    logging.info("||mu0_1||: " + str(np.linalg.norm(mu0_1)))
    logging.info("||mu0_2||: " + str(np.linalg.norm(mu0_2)))

    simT = sim['f'](stateT)

    # debug
    logging.info('match term (after flow): ' + str(simT[0]))

    Ediff = tj.Hamiltonian(q0,p0,mu0_1,mu0_2) # path energy from Hamiltonian
    logging.info('Hamiltonian: ' + str(Ediff))

    if not adjint:
        return weights[1]*simT[0]+weights[0]*Ediff

    dq = simT[1][0]
    if order >= 1:
        dq_1 = simT[1][1]
    else:
        dq_1 = np.zeros([N,DIM,DIM])
    if order >= 2:
        dq_2 = simT[1][2]
    else:
        dq_2 = np.zeros([N,DIM,DIM,DIM])
    # d(qT)/d(p0) is the identity, so the similarity gradient goes into the
    # momentum slots
    ds1 = tj.weinstein_darboux_to_state(np.zeros(dq.shape),np.zeros(dq_1.shape),np.zeros(dq_2.shape),dq,dq_1,dq_2,sim['N'],sim['DIM'])

    adjstate0 = np.append(np.zeros(ds1.size), ds1)
    ahalf = adjstate0.size // 2

    assert(nonmoving.size+x.size==ahalf)
    gradE = tj.grad_Hamiltonian(q0,p0,mu0_1,mu0_2)
    assert(ahalf-nonmoving.size == gradE.size) # gradE doesn't include point variations currently
    start = ahalf+nonmoving.size
    grad0 = np.array(weights[1]*adjstate0[start:] + weights[0]*gradE) # transported gradient + grad of energy
    # write the combined gradient back (as F does) so the energy term and the
    # match weight are included in the returned gradient
    adjstate0[start:] = grad0

    # get from symmetric to triangular form
    grad0 = tj.state_to_triangular(adjstate0[ahalf:adjstate0.size])[triunonmoving.size:triunonmoving.size+triux.size]

    grad0 = np.ndarray.flatten(grad0)

    # rescale (chain rule for the SIGMA scaling applied above)
    if scalegrad:
        if order >= 1:
            grad0[N*DIM:N*DIM+N*DIM**2] = tj.SIGMA*grad0[N*DIM:N*DIM+N*DIM**2]
        if order == 2:
            grad0[N*DIM+N*DIM**2:N*DIM+N*DIM**2+N*DIM**3] = tj.SIGMA*grad0[N*DIM+N*DIM**2:N*DIM+N*DIM**2+N*DIM**3]

    # debug
    logging.info("||dq||: " + str(np.linalg.norm(dq)))
    logging.info("||dq_1||: " + str(np.linalg.norm(dq_1)))
    logging.info("||dq_2||: " + str(np.linalg.norm(dq_2)))

    return (weights[1]*simT[0]+weights[0]*Ediff, grad0)