示例#1
0
文件: main.py 项目: xitorch/xitorch
def get_loss(pos0, vel0, ts, pos_target):
    """Integrate the system from its initial state and score the end position.

    The initial positions and velocities are stacked along a new leading
    axis to form the state passed to ``solve_ivp`` together with ``dydt``
    and ``ts``, using the ``"rk4"`` method.

    Returns:
        A ``(loss, yt)`` tuple: ``loss`` is the sum of squared differences
        between the final positions and ``pos_target``; ``yt`` is the value
        returned by ``solve_ivp``.
    """
    # Index 0 along the leading axis holds positions, index 1 velocities.
    initial_state = torch.stack((pos0, vel0), dim=0)
    yt = solve_ivp(dydt, ts, initial_state, method="rk4")

    # NOTE(review): assumes the leading axis of ``yt`` follows ``ts`` so that
    # -1 selects the last time point -- confirm against solve_ivp.
    final_pos = yt[-1, 0]  # (nbatch, nparticles, ndim)

    # Squared L2 norm of the flattened deviation from the target.
    flat_dev = (final_pos - pos_target).reshape(-1)
    return torch.dot(flat_dev, flat_dev), yt
示例#2
0
 def getoutput(a, b, c, ts, y0):
     """Solve the IVP defined by a ``clss(a, b)`` instance and return the result.

     The instance's ``forward`` method is used as the function handed to
     ``solve_ivp``, with ``c`` passed as its only extra parameter.
     ``clss`` and ``fwd_options`` are not defined in this block and must be
     supplied by the surrounding scope.
     """
     model = clss(a, b)
     extra_params = (c, )
     return solve_ivp(model.forward, ts, y0,
                      params=extra_params,
                      fwd_options=fwd_options)