Пример #1
0
class Satellite(RigidBody):
    """Main rigid body (node 0) carrying two hinged appendages (nodes 1 and 2).

    The main body has m=12 and moments (1/2, 1/2, 1/2). Each appendage has m=3
    and moments (1, 1, 1/8). The appendages attach to node 0 at offsets
    (1, 0, 0) and (0, -1, 0); each joint's rotation_axis pair is
    (attachment offset, (0, 0, 1)).
    """
    d = 3  # Cartesian embedding dimension
    D = 9  # TODO(review): original comment read "=3 euler coordinate dimension"; confirm meaning of 9
    n = 12  # size of the point axis of sampled states (see sample_initial_conditions)

    def __init__(self, mass=1, l=1):
        # NOTE(review): `mass` and `l` are accepted but never used.
        self.body_graph = BodyGraph()
        graph = self.body_graph
        graph.add_extended_nd(0, m=12, moments=(1 / 2, 1 / 2, 1 / 2))  # main body
        # (node id, attachment offset on node 0 -- also the first rotation axis)
        appendages = ((1, [1., 0, 0]), (2, [0., -1, 0]))
        for node, offset in appendages:
            graph.add_extended_nd(node, m=3, moments=(1, 1, 1 / 8))
            graph.add_joint(0,
                            torch.tensor(offset),
                            node,
                            torch.tensor([0., 0, 0]),
                            rotation_axis=(torch.tensor(offset),
                                           torch.tensor([0, 0, 1.])))

    def sample_initial_conditions(self, N):
        """Draw N standard-normal states of shape (N, 2, n, d) and project them
        onto the body-graph constraints."""
        raw_state = torch.randn(N, 2, self.n, self.d)
        return project_onto_constraints(self.body_graph, raw_state)

    def potential(self, x):
        """Potential energy: none is modelled, so the integer constant 0 is returned."""
        return 0
class Rotor(RigidBody):
    """A single free rigid body whose moments come from a triangle mesh.

    Body coordinates are 6-D: three centre-of-mass entries followed by three
    Euler angles (the latter are the angular dims).
    """
    d = 3  # Cartesian Embedding dimension
    D = 6  # =3 euler + 3 com Total body coordinate dimension
    angular_dims = range(3, 6)  # Euler-angle slots, i.e. slice(3, None)
    n = 4
    dt = 0.05
    integration_time = 5.

    def __init__(self, mass=.1, obj='rotor'):
        """Load ``obj + '.obj'``, recentre it on its centre of mass and register
        one free 3-D body (node 0).

        Args:
            mass: mass of the body.
            obj: mesh file stem passed to read_obj.
        """
        verts, tris = read_obj(obj + '.obj')
        _, com, covar = compute_moments(torch.from_numpy(verts[tris]))
        verts -= com.numpy()[None, :]  # set com as 0
        # Moments are the diagonal of the mesh covariance from compute_moments.
        moments = torch.diag(covar)
        self.obj = (verts, tris)
        self.body_graph = BodyGraph()
        self.body_graph.add_extended_nd(0, mass, moments, d=3)

    def sample_initial_conditions(self, N):
        """Sample N (position, velocity) states in Cartesian body coordinates.

        Each of the 6 centre-of-mass/Euler components is drawn as 2*randn and
        clamped to [-3, 3]. The centre-of-mass part (first three entries) is
        then scaled by 0.1. The result is mapped through comEuler2bodyX.
        """
        comEulers = (2 * torch.randn(N, 2, 6)).clamp(max=3, min=-3)
        comEulers[:, :, :3] *= .1
        return comEuler2bodyX(comEulers)

    def potential(self, x):
        """No external potential: returns 0."""
        return 0.

    def body2globalCoords(self, comEulers):
        """ input: (bs,2,6) output: (bs,2,4,3) """
        return comEuler2bodyX(comEulers)

    def global2bodyCoords(self, bodyX):
        """ input: (bs,2,4,3) output: (bs,2,6)"""
        comEuler = bodyX2comEuler(bodyX)
        # Unwrap the Euler angles along axis 0 (presumably time) so that
        # trajectories are continuous. np.unwrap needs a CPU, grad-free array;
        # passing a CUDA or requires_grad tensor straight in would raise.
        angles = comEuler[:, 0, 3:].detach().cpu().numpy()
        unwrapped_angles = torch.from_numpy(np.unwrap(angles, axis=0))
        comEuler[:, 0, 3:] = unwrapped_angles.to(bodyX.device, bodyX.dtype)
        return comEuler

    @property
    def animator(self):
        return RigidAnimation
 def __init__(self, mass=.1, obj='gyro'):
     """Load ``obj + '.obj'`` and register one rigid body attached by a joint.

     The mesh is shifted so its minimum z coordinate is 0. compute_moments
     supplies the centre of mass `com` and a covariance `covar`. Node 0 gets
     mass `mass` and moments 100 * diag(covar), and a joint is added for node 0
     at offset -com with pos2 = (0, 0, 0).
     """
     verts,tris =  read_obj(obj+'.obj')
     verts[:,2] -= verts[:,2].min() # set bottom as 0
     _,com,covar = compute_moments(torch.from_numpy(verts[tris]))
     # Diagnostic output: diag(trace(covar)*I - covar) and diag(covar).
     # NOTE(review): prints on every construction; consider logging instead.
     print(torch.diag(torch.diag(covar).sum()*torch.eye(3)-covar),torch.diag(covar))
     # Keep the mesh and its centre of mass (as a numpy array).
     self.obj = (verts,tris,com.numpy())
     self.body_graph =  BodyGraph()
     self.body_graph.add_extended_nd(0,m=mass,moments=100*torch.diag(covar))
     self.body_graph.add_joint(0,-com,pos2=torch.tensor([0.,0.,0.]))
    def __init__(self, mass=.1, obj='rotor'):  # (moments were once a parameter: moments=(1,2,3))
        """Load ``obj + '.obj'``, recentre it on its centre of mass and register
        one free 3-D body (node 0) with mass `mass`.

        The body's moments are the diagonal of the covariance returned by
        compute_moments.
        """
        verts, tris = read_obj(obj + '.obj')

        _, com, covar = compute_moments(torch.from_numpy(verts[tris]))
        verts -= com.numpy()[None, :]  # set com as 0
        # verts*=100
        # NOTE(review): eigs and Q are never used below.
        eigs, Q = np.linalg.eigh(covar.numpy())
        #print(compute_moments(torch.from_numpy((verts@Q)[tris])))
        moments = torch.diag(covar)
        self.obj = (verts, tris)
        self.body_graph = BodyGraph()
        self.body_graph.add_extended_nd(0, mass, moments, d=3)
Пример #5
0
 def __init__(self, bobs=2, m=1, l=1,k=10):
     """Build `bobs` tethered point masses with anchors spaced along x.

     Args:
         bobs: number of point masses.
         m: mass of every bob, or None to draw each mass as 0.6 + 0.8*U[0, 1)
             under FixedNumpySeed(0).
         l: tether length shared by all bobs.
         k: value stored in self.ks, one entry per adjacent pair (bobs-1).
     """
     self.body_graph = BodyGraph()#nx.Graph()
     # Config tag; 'r' marks randomly drawn masses.
     self.arg_string = f"n{bobs}m{m or 'r'}l{l}"
     with FixedNumpySeed(0):
         ms = [.6+.8*np.random.rand() for _ in range(bobs)] if m is None else bobs*[m]
     self.ms = copy.deepcopy(ms)  # copy kept because ms is consumed by pop() below
     ls = bobs*[l]
     self.ks = torch.tensor((bobs-1)*[k]).float()
     # Anchor i sits at (i, 0, 0).
     self.locs = torch.zeros(bobs,3)
     self.locs[:,0] = 1*torch.arange(bobs).float()
     for i in range(bobs):
         # pop() takes from the end, so node i receives ms[bobs-1-i].
         self.body_graph.add_extended_nd(i, m=ms.pop(), d=0,tether=(self.locs[i],ls.pop()))
     self.n = bobs
     self.D = 2*self.n # Spherical coordinates, phi, theta per bob
     self.angular_dims = range(self.D)
Пример #6
0
 def __init__(self, mass=1, l=1):
     """Build a main body (node 0) with two hinged appendages (nodes 1, 2).

     NOTE(review): `mass` and `l` are accepted but not used.
     """
     self.body_graph = BodyGraph()
     self.body_graph.add_extended_nd(0, m=12, moments=(1 / 2, 1 / 2,
                                                       1 / 2))  #main body
     # Appendage 1: attached at (1, 0, 0) on node 0; rotation_axis pair is
     # ((1, 0, 0), (0, 0, 1)).
     self.body_graph.add_extended_nd(1, m=3, moments=(1, 1, 1 / 8))
     self.body_graph.add_joint(0,
                               torch.tensor([1., 0, 0]),
                               1,
                               torch.tensor([0., 0, 0]),
                               rotation_axis=(torch.tensor([1., 0, 0]),
                                              torch.tensor([0, 0, 1.])))
     # Appendage 2: attached at (0, -1, 0) on node 0; rotation_axis pair is
     # ((0, -1, 0), (0, 0, 1)).
     self.body_graph.add_extended_nd(2, m=3, moments=(1, 1, 1 / 8))
     self.body_graph.add_joint(0,
                               torch.tensor([0., -1, 0]),
                               2,
                               torch.tensor([0., 0, 0]),
                               rotation_axis=(torch.tensor([0., -1, 0]),
                                              torch.tensor([0, 0, 1.])))
 def __init__(self, links=2, beams=False, m=None, l=None):
     """Build a planar chain of `links` point masses.

     Node 0 is tethered to (0, 0); node i is linked to node i-1 by an edge of
     length l. When m (or l) is None, per-link values are drawn as
     0.6 + 0.8*U[0, 1) under FixedNumpySeed(0).

     Raises:
         AssertionError: if beams is truthy (not supported).
     """
     self.body_graph = BodyGraph()  #nx.Graph()
     # Config tag; 'b' marks beams, 'r' marks randomly drawn values.
     self.arg_string = f"n{links}{'b' if beams else ''}m{m or 'r'}l{l or 'r'}"
     assert not beams, "beams temporarily not supported"
     with FixedNumpySeed(0):
         ms = [.6 + .8 * np.random.rand()
               for _ in range(links)] if m is None else links * [m]
         ls = [.6 + .8 * np.random.rand()
               for _ in range(links)] if l is None else links * [l]
     self.ms = copy.deepcopy(ms)  # copy kept because ms is consumed by pop() below
     # pop() takes values from the end of ms / ls.
     self.body_graph.add_extended_nd(0,
                                     m=ms.pop(),
                                     d=0,
                                     tether=(torch.zeros(2), ls.pop()))
     for i in range(1, links):
         self.body_graph.add_extended_nd(i, m=ms.pop(), d=0)
         self.body_graph.add_edge(i - 1, i, l=ls.pop())
     self.D = self.n = links  # one angle per link
     self.angular_dims = range(links)
 def __init__(self, links=2, beams=False, m=1, l=1):
     """Build a chain of `links` links, either as beams or as point masses.

     With beams=True, every link is a d=1 extended body of mass m with moment
     m*l*l/12 (matching the thin-rod moment about its centre). Node 0 gets a
     joint at l/2, and consecutive links are joined at -l/2 (previous) and
     l/2 (next). Otherwise, point nodes of mass m are added: node 0 carries
     tether=torch.zeros(2) and l=l, and consecutive nodes share an edge of
     length l.

     NOTE(review): this constructor does not set self.n, self.D or
     self.angular_dims.
     """
     self.body_graph = BodyGraph()  #nx.Graph()
     self.arg_string = f"n{links}{'b' if beams else ''}m{m}l{l}"
     beam_moments = torch.tensor([m * l * l / 12])
     if beams:
         self.body_graph.add_extended_nd(0, m=m, moments=beam_moments, d=1)
         self.body_graph.add_joint(0, torch.tensor([l / 2]))
         for i in range(1, links):
             self.body_graph.add_extended_nd(i,
                                             m=m,
                                             moments=beam_moments,
                                             d=1)
             self.body_graph.add_joint(i - 1, torch.tensor([-l / 2]), i,
                                       torch.tensor([l / 2]))
     else:
         self.body_graph.add_node(0, m=m, tether=torch.zeros(2), l=l)
         for i in range(1, links):
             self.body_graph.add_node(i, m=m)
             self.body_graph.add_edge(i - 1, i, l=l)
 def __init__(self, mass=3, l=1, q=.3, magnets=2):
     """Build one bob tethered at the origin with length l, above fixed dipoles.

     Args:
         mass: bob mass, or None to draw 2.4 + 0.8*U[0, 1) under
             FixedNumpySeed(0).
         l: tether length.
         q: magnitude stored in self.q and used to scale the fixed dipoles.
         magnets: number of fixed dipoles, evenly spaced in angle on a circle
             of radius 0.1 in the plane z = -1.05*l.
     """
     with FixedNumpySeed(0):
         mass = np.random.rand() * .8 + 2.4 if mass is None else mass
     self.ms = [mass]
     self.arg_string = f"m{mass or 'r'}l{l}q{q}mn{magnets}"
     self.body_graph = BodyGraph()
     self.body_graph.add_extended_nd(0,
                                     m=mass,
                                     d=0,
                                     tether=(torch.zeros(3), l))
     self.q = q  # magnetic moment magnitude
     # `magnets` evenly spaced angles in [0, 2*pi); the endpoint is dropped.
     theta = torch.linspace(0, 2 * np.pi, magnets + 1)[:-1]
     self.magnet_positions = torch.stack(
         [
             0.1 * theta.cos(), 0.1 * theta.sin(),
             -(1.05) * l * torch.ones_like(theta)
         ],
         dim=-1,
     )
     self.magnet_dipoles = q * torch.stack(
         [0 * theta, 0 * theta,
          torch.ones_like(theta)], dim=-1)  # +z direction
Пример #10
0
class CoupledPendulum(MagnetPendulum):
    """Spherical pendulums hung from anchors spaced along x, with springs
    between neighbouring bobs.

    Body coordinates are (phi, theta) per bob. The angle conversion is
    inherited from MagnetPendulum after shifting each bob by its anchor.
    """
    d = 3

    def __init__(self, bobs=2, m=1, l=1, k=10):
        """Build `bobs` tethered point masses with anchors spaced along x.

        Args:
            bobs: number of point masses.
            m: mass of every bob, or None to draw each mass as 0.6 + 0.8*U[0, 1)
                under FixedNumpySeed(0).
            l: tether length shared by all bobs.
            k: spring constant for each of the bobs-1 neighbouring pairs.
        """
        self.body_graph = BodyGraph()
        self.arg_string = f"n{bobs}m{m or 'r'}l{l}"
        with FixedNumpySeed(0):
            ms = [.6+.8*np.random.rand() for _ in range(bobs)] if m is None else bobs*[m]
        self.ms = copy.deepcopy(ms)  # ms is consumed by pop() below
        ls = bobs*[l]
        self.ks = torch.tensor((bobs-1)*[k]).float()
        # Anchor i sits at (i, 0, 0).
        self.locs = torch.zeros(bobs, 3)
        self.locs[:, 0] = 1*torch.arange(bobs).float()
        for i in range(bobs):
            self.body_graph.add_extended_nd(i, m=ms.pop(), d=0, tether=(self.locs[i], ls.pop()))
        self.n = bobs
        self.D = 2*self.n  # Spherical coordinates, phi, theta per bob
        self.angular_dims = range(self.D)

    def sample_initial_conditions(self, bs):
        """Sample bs initial (position, velocity) states of shape (bs, 2, n, 3).

        Angles and angular velocities are 0.3*randn. The angles are offset to
        phi ~ pi and theta ~ pi/2, mapped to Cartesian coordinates, and
        projected onto the constraints. Draws whose projection raises
        OverflowError are resampled.
        """
        n = len(self.body_graph.nodes)
        # Loop rather than recurse so repeated failures cannot hit RecursionError.
        while True:
            angles_and_angvel = .3*torch.randn(bs, 2, 2*n)  # (bs,2,2n)
            angles_and_angvel[:, 0, 1::2] += np.pi/2  # theta of each bob
            angles_and_angvel[:, 0, ::2] += np.pi  # phi of each bob
            z = self.body2globalCoords(angles_and_angvel)  # (bs,2,n,d)
            try:
                return project_onto_constraints(self.body_graph, z, tol=1e-5)
            except OverflowError:
                continue

    def global2bodyCoords(self, global_pos_vel):
        """ input (bs,2,n,3) output (bs,2,dangular=2n) """
        # clone() instead of copy.deepcopy: deepcopy raises on non-leaf tensors.
        xyz = global_pos_vel.clone()
        xyz[:, 0] -= self.locs.to(xyz.device, xyz.dtype)  # positions relative to anchors
        return super().global2bodyCoords(xyz)

    def body2globalCoords(self, angles_omega):
        """ input (bs,2,dangular=2n) output (bs,2,n,3) """
        xyz = super().body2globalCoords(angles_omega)
        xyz[:, 0] += self.locs.to(xyz.device, xyz.dtype)  # shift each bob to its anchor
        return xyz  # (bs,2,n,3)

    def potential(self, x):
        """Gravity plus neighbour-spring potential.

        inputs [x (bs,n,d)] outputs [V (bs,)]. Consecutive bobs are joined by
        springs with stiffness self.ks. Each rest length is the distance
        between the corresponding anchors in self.locs.
        """
        gpe = 9.81*(self.M @ x)[..., 2].sum(1)
        l0s = ((self.locs[1:]-self.locs[:-1])**2).sum(-1).sqrt().to(x.device, x.dtype)
        xdist = ((x[:, 1:, :]-x[:, :-1, :])**2).sum(-1).sqrt()
        spring_energy = (.5*self.ks.to(x.device, x.dtype)*(xdist-l0s)**2).sum(1)
        return gpe+spring_energy

    @property
    def animator(self):
        return CoupledPendulumAnimation
class MagnetPendulum(RigidBody):
    """One bob tethered at the origin, swinging above fixed magnetic dipoles.

    Body coordinates are the spherical angles (phi, theta) of the bob,
    measured in a frame rotated by 90 degrees about y.
    """
    d = 3
    n = 1
    D = 2
    angular_dims = range(2)
    dt = 0.05
    integration_time = 5.

    def __init__(self, mass=3, l=1, q=.3, magnets=2):
        """Build the tethered bob and the ring of fixed dipoles.

        Args:
            mass: bob mass, or None to draw 2.4 + 0.8*U[0, 1) under
                FixedNumpySeed(0).
            l: tether length.
            q: magnetic moment magnitude (bob and fixed dipoles).
            magnets: number of fixed dipoles, evenly spaced in angle on a
                circle of radius 0.1 in the plane z = -1.05*l, all along +z.
        """
        with FixedNumpySeed(0):
            mass = np.random.rand() * .8 + 2.4 if mass is None else mass
        self.ms = [mass]
        self.arg_string = f"m{mass or 'r'}l{l}q{q}mn{magnets}"
        self.body_graph = BodyGraph()
        self.body_graph.add_extended_nd(0,
                                        m=mass,
                                        d=0,
                                        tether=(torch.zeros(3), l))
        self.q = q  # magnetic moment magnitude
        theta = torch.linspace(0, 2 * np.pi, magnets + 1)[:-1]
        self.magnet_positions = torch.stack(
            [
                0.1 * theta.cos(), 0.1 * theta.sin(),
                -(1.05) * l * torch.ones_like(theta)
            ],
            dim=-1,
        )
        self.magnet_dipoles = q * torch.stack(
            [0 * theta, 0 * theta,
             torch.ones_like(theta)], dim=-1)  # +z direction

    def sample_initial_conditions(self, N):
        """Sample N states of shape (N, 2, 1, 3).

        phi ~ pi + 0.3*randn, dphi ~ 0.05*randn, theta ~ pi/2 + 0.2*randn and
        dtheta ~ 0.4*randn.
        """
        angles_omega = torch.zeros(N, 2, 2)
        angles_omega[:, 0, 0] = np.pi + .3 * torch.randn(N)
        angles_omega[:, 1, 0] = .05 * torch.randn(N)
        angles_omega[:, 0, 1] = np.pi / 2 + .2 * torch.randn(N)
        angles_omega[:, 1, 1] = .4 * torch.randn(N)
        return self.body2globalCoords(angles_omega)

    def global2bodyCoords(self, global_pos_vel):
        """ input (bs,2,1,3) output (bs,2,dangular=2n) """
        bsT, _, n, d = global_pos_vel.shape
        x, y, z = global_pos_vel[:, 0, :, :].permute(2, 0, 1)
        xd, yd, zd = global_pos_vel[:, 1, :, :].permute(2, 0, 1)
        x, z, xd, zd = z, -x, zd, -xd  # Rotate coordinate system by 90 about y
        phi = torch.atan2(y, x)
        rz = (x**2 + y**2).sqrt()
        r = (rz**2 + z**2).sqrt()
        theta = torch.atan2(rz, z)
        phid = (x * yd - y * xd) / rz**2
        thetad = ((xd * x * z + yd * y * z) / rz - rz * zd) / r**2
        angles = torch.stack([phi, theta], dim=-1)
        # Unwrap along axis 0. np.unwrap needs a CPU, grad-free array;
        # .numpy() alone raises for CUDA or requires_grad tensors.
        angles = torch.from_numpy(
            np.unwrap(angles.detach().cpu().numpy(),
                      axis=0)).to(r.device, r.dtype)
        anglesd = torch.stack([phid, thetad], dim=-1)
        angles_omega = torch.stack([angles, anglesd], dim=1)
        return angles_omega.reshape(bsT, 2, 2 * n)

    def body2globalCoords(self, angles_omega):
        """ input (bs,2,dangular=2) output (bs,2,1,3) """
        bs, _, n2 = angles_omega.shape
        n = n2 // 2
        euler_angles = torch.zeros(n * bs,
                                   2,
                                   3,
                                   device=angles_omega.device,
                                   dtype=angles_omega.dtype)
        euler_angles[:, :, :2] = angles_omega.reshape(bs, 2, n, 2).permute(
            2, 0, 1, 3).reshape(n * bs, 2, 2)
        # To treat z axis of ZXZ euler angles as spherical coordinates
        # simply set (alpha,beta,gamma) = (phi+pi/2,theta,0)
        euler_angles[:, 0, 0] += np.pi / 2
        zhat_p = euler2frame(euler_angles)[:, :, 2]
        zhat_p[:, :, [0, 2]] = zhat_p[:, :, [2, 0]]
        zhat_p[:, :, 0] *= -1  # rotate coordinates by -90 about y
        return zhat_p.reshape(n, bs, 2, 3).permute(1, 2, 0, 3)  # (bs,2,n,3)

    def potential(self, x):
        """Gravity plus dipole-dipole energy between the bob and the fixed magnets.

        inputs [x (bs,1,d)] outputs [V (bs,)]. The bob's dipole is
        q * r0 / |r0|^2, i.e. it points along the tether.
        """
        gpe = 9.81 * (self.M @ x)[..., :, 2].sum(-1)
        ri = self.magnet_positions.to(x.device, x.dtype)
        mi = self.magnet_dipoles[None].to(x.device, x.dtype)  # (1,magnets,d)
        r0 = x.squeeze(-2)  # (bs,1,d) -> (bs,d)
        m0 = (self.q * r0 / (r0**2).sum(-1, keepdims=True))[:, None]  # (bs,1,d)
        r0i = (ri[None] - r0[:, None])  # (bs,magnets,d)
        m0dotr0i = (m0 * r0i).sum(-1)  # (bs,magnets)
        midotr0i = (mi * r0i).sum(-1)
        m0dotmi = (m0 * mi).sum(-1)
        r0inorm2 = (r0i * r0i).sum(-1)
        # NOTE(review): the textbook dipole-dipole energy is
        # (r^2 (m0.mi) - 3 (m0.r)(mi.r)) / (4 pi r^5); the r^2 (m0.mi) term
        # here has the opposite sign. Left unchanged because data generated
        # with this potential depends on it -- confirm intent before changing.
        dipole_energy = ((-3 * m0dotr0i * midotr0i - r0inorm2 * m0dotmi) /
                         (4 * np.pi * r0inorm2**(5 / 2))).sum(-1)  # (bs,)
        return gpe + dipole_energy  # (bs,)

    def __str__(self):
        return f"{self.__class__}{self.arg_string}"

    def __repr__(self):
        return str(self)

    @property
    def animator(self):
        return MagnetPendulumAnimation
class Gyroscope(RigidBody):
    """A single mesh-derived rigid body attached by one joint, under gravity.

    Body coordinates are the three Euler angles. body2globalCoords shifts the
    state so the joint attachment point (and its velocity) is zero.
    """
    d = 3  # Cartesian Embedding dimension
    D = 3  # =3 euler coordinate dimension
    angular_dims = range(3)
    n = 4
    dt = 0.02
    integration_time = 2

    def __init__(self, mass=.1, obj='gyro'):
        """Load ``obj + '.obj'`` and register one body with a joint at -com.

        The mesh is shifted so its minimum z is 0. Node 0 gets mass `mass` and
        moments 100 * diag(covar), where covar comes from compute_moments.
        """
        verts, tris = read_obj(obj+'.obj')
        verts[:, 2] -= verts[:, 2].min()  # set bottom as 0
        _, com, covar = compute_moments(torch.from_numpy(verts[tris]))
        # NOTE(review): diagnostic print on every construction; consider logging.
        print(torch.diag(torch.diag(covar).sum()*torch.eye(3)-covar), torch.diag(covar))
        self.obj = (verts, tris, com.numpy())
        self.body_graph = BodyGraph()
        self.body_graph.add_extended_nd(0, m=mass, moments=100*torch.diag(covar))
        self.body_graph.add_joint(0, -com, pos2=torch.tensor([0., 0., 0.]))

    def sample_initial_conditions(self, N):
        """Sample N states of shape (N, 2, 4, 3) from random Euler angles/rates.

        Angles and rates start uniform in [-1.5, 1.5). The first rate is then
        scaled by 3 and the second by 0.2. The third rate is set to
        +/- 1.5*(7 + randn), with a random sign.
        """
        eulers = (torch.rand(N, 2, 3)-.5)*3
        eulers[:, 1, 0] *= 3
        eulers[:, 1, 1] *= .2
        eulers[:, 1, 2] = (torch.randint(2, size=(N,)).float()*2-1)*(torch.randn(N)+7)*1.5
        return self.body2globalCoords(eulers)

    def body2globalCoords(self, eulers):
        """ input: (bs,2,3) output: (bs,2,4,3) """
        coms = torch.zeros_like(eulers)
        comEulers = torch.cat([coms, eulers], dim=-1)
        bodyX = comEuler2bodyX(comEulers)
        # Offset x and v so that the joint attachment point is stationary at
        # the origin. The attachment point is the affine combination (weights
        # summing to 1) of the 4 body points given by ct.
        body_attachment = self.body_graph.nodes[0]['joint'][0].to(eulers.device, eulers.dtype)
        ct = torch.cat([1-body_attachment.sum()[None], body_attachment])
        global_coords_attachment_point = (bodyX*ct[:, None]).sum(-2, keepdims=True)  # (bs,2,1,3)
        return bodyX-global_coords_attachment_point

    def global2bodyCoords(self, bodyX):
        """ input: (bs,2,4,3) output: (bs,2,3)"""
        eulers = bodyX2comEuler(bodyX)[..., 3:]
        # Unwrap the Euler angles along axis 0. np.unwrap needs a CPU,
        # grad-free array; .numpy() alone raises for CUDA or requires_grad tensors.
        eulers[:, 0, :] = torch.from_numpy(
            np.unwrap(eulers[:, 0, :].detach().cpu().numpy(), axis=0)).to(bodyX.device, bodyX.dtype)
        return eulers

    def potential(self, x):
        """ Gravity potential """
        return 9.81*(self.M @ x)[..., 2].sum(1)

    @property
    def animator(self):
        return RigidAnimation
class ChainPendulum(RigidBody):
    """Planar chain of point masses under gravity (y is the vertical axis).

    Node 0 is tethered to a fixed anchor; node i is linked to node i-1 by an
    edge of fixed length. Body coordinates are one angle (plus angular
    velocity) per link.
    """
    d = 2
    dt = .03
    integration_time = 3

    def __init__(self, links=2, beams=False, m=None, l=None):
        """Build a chain of `links` point masses.

        When m (or l) is None, per-link values are drawn as 0.6 + 0.8*U[0, 1)
        under FixedNumpySeed(0).

        Raises:
            AssertionError: if beams is truthy (not supported).
        """
        self.body_graph = BodyGraph()
        self.arg_string = f"n{links}{'b' if beams else ''}m{m or 'r'}l{l or 'r'}"
        assert not beams, "beams temporarily not supported"
        with FixedNumpySeed(0):
            ms = [.6 + .8 * np.random.rand()
                  for _ in range(links)] if m is None else links * [m]
            ls = [.6 + .8 * np.random.rand()
                  for _ in range(links)] if l is None else links * [l]
        self.ms = copy.deepcopy(ms)  # ms is consumed by pop() below
        self.body_graph.add_extended_nd(0,
                                        m=ms.pop(),
                                        d=0,
                                        tether=(torch.zeros(2), ls.pop()))
        for i in range(1, links):
            self.body_graph.add_extended_nd(i, m=ms.pop(), d=0)
            self.body_graph.add_edge(i - 1, i, l=ls.pop())
        self.D = self.n = links
        self.angular_dims = range(links)

    def body2globalCoords(self, angles_omega):
        """ input (N,2,n) angles/angular velocities -> output (N,2,n,2) positions/velocities """
        d = 2
        n = len(self.body_graph.nodes)
        N = angles_omega.shape[0]
        pvs = torch.zeros(N,
                          2,
                          n,
                          d,
                          device=angles_omega.device,
                          dtype=angles_omega.dtype)
        global_position_velocity = torch.zeros(N,
                                               2,
                                               d,
                                               device=angles_omega.device,
                                               dtype=angles_omega.dtype)
        length = self.body_graph.nodes[0]["tether"][1]
        global_position_velocity[:, 0, :] = self.body_graph.nodes[0]["tether"][
            0][None]
        global_position_velocity += self.joint2cartesian(
            length, angles_omega[..., 0])
        pvs[:, :, 0] = global_position_velocity
        # Accumulate each link's relative offset along the chain.
        for (_, j), length in nx.get_edge_attributes(self.body_graph,
                                                     "l").items():
            global_position_velocity += self.joint2cartesian(
                length, angles_omega[..., j])
            pvs[:, :, j] = global_position_velocity
        return pvs

    def joint2cartesian(self, length, angle_omega):
        """Map (N, 2) [angle, omega] to an (N, 2, 2) relative (position, velocity).

        Position is (l sin a, -l cos a); velocity is its time derivative.
        """
        position_vel = torch.zeros(angle_omega.shape[0],
                                   2,
                                   2,
                                   device=angle_omega.device,
                                   dtype=angle_omega.dtype)
        position_vel[:, 0, 0] = length * angle_omega[:, 0].sin()
        position_vel[:, 1,
                     0] = length * angle_omega[:, 0].cos() * angle_omega[:, 1]
        position_vel[:, 0, 1] = -length * angle_omega[:, 0].cos()
        position_vel[:, 1,
                     1] = length * angle_omega[:, 0].sin() * angle_omega[:, 1]
        return position_vel

    def cartesian2angle(self, rel_pos_vel):
        """Inverse of joint2cartesian: (N, 2, 2) relative pos/vel -> (N, 2) [angle, omega].

        The angle is atan2(x, -y), so angle 0 corresponds to (0, -l). It is
        unwrapped along axis 0.
        """
        x, y = rel_pos_vel[:, 0].T
        vx, vy = rel_pos_vel[:, 1].T
        angle = torch.atan2(x, -y)
        # Exact derivative of atan2(x, -y). The former branchy form
        # torch.where(angle < 1e-2, vx / (-y), vy / x) divided by ~0 for angles
        # near -pi/2 (y -> 0) and near +pi (x -> 0).
        omega = (x * vy - y * vx) / (x**2 + y**2)
        # np.unwrap needs a CPU, grad-free array.
        angle_unwrapped = torch.from_numpy(
            np.unwrap(angle.detach().cpu().numpy(), axis=0)).to(x.device, x.dtype)
        return torch.stack([angle_unwrapped, omega], dim=1)

    def global2bodyCoords(self, global_pos_vel):
        """ input (*bs,2,n,2) positions/velocities -> output (*bs,2,n) angles/angular velocities """
        *bsT2, n, d = global_pos_vel.shape
        angles_omega = torch.zeros(*bsT2,
                                   n,
                                   device=global_pos_vel.device,
                                   dtype=global_pos_vel.dtype)
        start_position_velocity = torch.zeros(*bsT2,
                                              d,
                                              device=angles_omega.device,
                                              dtype=angles_omega.dtype)
        start_position_velocity[
            ..., 0, :] = self.body_graph.nodes[0]["tether"][0][None]
        rel_pos_vel = global_pos_vel[..., 0, :] - start_position_velocity
        angles_omega[..., 0] += self.cartesian2angle(rel_pos_vel)
        start_position_velocity += rel_pos_vel
        # Each link's angle comes from its offset relative to the previous node.
        for (_, j), _length in nx.get_edge_attributes(self.body_graph,
                                                      "l").items():
            rel_pos_vel = global_pos_vel[..., j, :] - start_position_velocity
            angles_omega[..., j] += self.cartesian2angle(rel_pos_vel)
            start_position_velocity += rel_pos_vel
        return angles_omega

    def sample_initial_conditions(self, N):
        """Sample N states of shape (N, 2, n, 2) and project onto the constraints.

        Draws whose projection raises OverflowError are resampled in a loop
        (not by recursion).
        """
        n = len(self.body_graph.nodes)
        while True:
            angles_and_angvel = torch.randn(N, 2, n)  # (N,2,n)
            z = self.body2globalCoords(angles_and_angvel)
            z[:, 0] += .2 * torch.randn(N, n, 2)
            z[:, 1] = (.5 * z[:, 1] + .4 * torch.randn(N, n, 2)) * 3
            try:
                return project_onto_constraints(self.body_graph, z, tol=1e-5)
            except OverflowError:
                continue

    def potential(self, x):
        """ Gravity potential """
        return 9.81 * (self.M @ x)[..., 1].sum(1)

    def __str__(self):
        return f"{self.__class__}{self.arg_string}"

    def __repr__(self):
        return str(self)

    @property
    def animator(self):
        return PendulumAnimation