Example #1
    def fit(self, keypoints3d, dtype='coco', verbose=True):
        """Run fitting to optimize the SMPL parameters."""
        assert dtype == 'coco', 'only supports the COCO format for now.'
        assert keypoints3d.ndim == 3, 'input shape should be [N, njoints, 3]'
        mapping_target = unify_joint_mappings(dataset=dtype)
        keypoints3d = keypoints3d[:, mapping_target, :]
        keypoints3d = torch.from_numpy(keypoints3d).float().to(self.device)
        batch_size, njoints = keypoints3d.shape[0:2]

        # Initialize the learnable SMPL model.
        smpl = SMPL(model_path=self.smpl_model_path,
                    gender=self.smpl_model_gender,
                    batch_size=batch_size).to(self.device)

        # Start fitting. The optimizer is rebuilt each step, presumably so
        # `get_optimizer` can vary the learning rate over the run.
        for step in range(self.niter):
            optimizer = self.get_optimizer(smpl, step, self.base_lr)

            output = smpl.forward()
            joints = output.joints[:, self.joints_mapping_smpl[:njoints], :]
            loss = self.metric(joints, keypoints3d)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if verbose and step % 10 == 0:
                logging.info(f'step {step:03d}; loss {loss.item():.3f};')

        # Return results
        return smpl, loss.item()
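For orientation, here is a minimal sketch of how this fitter might be invoked. The wrapper class name `SMPLFitter` and its constructor arguments are hypothetical; only the `fit` signature comes from the snippet above:

import numpy as np

# `SMPLFitter` is a hypothetical wrapper exposing the `fit` method above;
# its name and constructor arguments are assumptions for illustration.
fitter = SMPLFitter(smpl_model_path='data/smpl',
                    smpl_model_gender='MALE',
                    device='cuda')

# COCO-format 3D keypoints: [N frames, 17 COCO joints, xyz].
keypoints3d = np.random.randn(8, 17, 3).astype(np.float32)

smpl, final_loss = fitter.fit(keypoints3d, dtype='coco', verbose=True)
print(f'final loss: {final_loss:.3f}')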
Example #2
def main(_):
  # Parsing data info.
  aist_dataset = AISTDataset(FLAGS.anno_dir)
  video_path = os.path.join(FLAGS.video_dir, f'{FLAGS.video_name}.mp4')
  seq_name, view = AISTDataset.get_seq_name(FLAGS.video_name)
  view_idx = AISTDataset.VIEWS.index(view)

  # Parsing keypoints.
  if FLAGS.mode == '2D':  # raw keypoint detection results.
    keypoints2d, _, _ = AISTDataset.load_keypoint2d(
        aist_dataset.keypoint2d_dir, seq_name)
    keypoints2d = keypoints2d[view_idx, :, :, 0:2]

  elif FLAGS.mode == '3D':  # 3D keypoints with temporal optimization.
    keypoints3d = AISTDataset.load_keypoint3d(
        aist_dataset.keypoint3d_dir, seq_name, use_optim=True)
    nframes, njoints, _ = keypoints3d.shape
    env_name = aist_dataset.mapping_seq2env[seq_name]
    cgroup = AISTDataset.load_camera_group(aist_dataset.camera_dir, env_name)
    keypoints2d = cgroup.project(keypoints3d)
    keypoints2d = keypoints2d.reshape(9, nframes, njoints, 2)[view_idx]

  elif FLAGS.mode == 'SMPL':  # SMPL joints
    smpl_poses, smpl_scaling, smpl_trans = AISTDataset.load_motion(
        aist_dataset.motion_dir, seq_name)
    smpl = SMPL(model_path=FLAGS.smpl_dir, gender='MALE', batch_size=1)
    keypoints3d = smpl.forward(
        global_orient=torch.from_numpy(smpl_poses[:, 0:1]).float(),
        body_pose=torch.from_numpy(smpl_poses[:, 1:]).float(),
        transl=torch.from_numpy(smpl_trans).float(),
        scaling=torch.from_numpy(smpl_scaling.reshape(1, 1)).float(),
        ).joints.detach().numpy()

    nframes, njoints, _ = keypoints3d.shape
    env_name = aist_dataset.mapping_seq2env[seq_name]
    cgroup = AISTDataset.load_camera_group(aist_dataset.camera_dir, env_name)
    keypoints2d = cgroup.project(keypoints3d)
    keypoints2d = keypoints2d.reshape(9, nframes, njoints, 2)[view_idx]

  # Visualize.
  os.makedirs(FLAGS.save_dir, exist_ok=True)
  save_path = os.path.join(FLAGS.save_dir, f'{FLAGS.video_name}.mp4')
  plot_on_video(keypoints2d, video_path, save_path, fps=60)
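The script reads its inputs from absl flags. A plausible definition block, consistent with the flag names used above (the defaults and help strings are assumptions):

from absl import app, flags

FLAGS = flags.FLAGS
# Flag names are taken from the snippet above; defaults are assumptions.
flags.DEFINE_string('anno_dir', None, 'Path to the AIST++ annotation folder.')
flags.DEFINE_string('video_dir', None, 'Path to the AIST++ video folder.')
flags.DEFINE_string('video_name', None, 'Video name, without the .mp4 suffix.')
flags.DEFINE_string('smpl_dir', None, 'Path to the SMPL model files.')
flags.DEFINE_string('save_dir', None, 'Output folder for the rendered video.')
flags.DEFINE_enum('mode', '2D', ['2D', '3D', 'SMPL'], 'Keypoint source.')

if __name__ == '__main__':
  app.run(main)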
Example #3
    def fit(self, keypoints3d, dtype='coco', verbose=True):
        """Run fitting to optimize the SMPL parameters."""
        assert dtype == 'coco', 'only supports the COCO format for now.'
        assert keypoints3d.ndim == 3, 'input shape should be [N, njoints, 3]'
        mapping_target = unify_joint_mappings(dataset=dtype)
        keypoints3d = keypoints3d[:, mapping_target, :]
        keypoints3d = torch.from_numpy(keypoints3d).float().to(self.device)
        batch_size, njoints = keypoints3d.shape[0:2]

        # Initialize the learnable SMPL model.
        smpl = SMPL(model_path=self.smpl_model_path,
                    gender=self.smpl_model_gender,
                    batch_size=batch_size).to(self.device)

        # Start fitting. The optimizer is rebuilt each step, presumably so
        # `get_optimizer` can vary the learning rate over the run.
        for step in range(self.niter):
            optimizer = self.get_optimizer(smpl, step, self.base_lr)

            output = smpl.forward()
            joints = output.joints[:, self.joints_mapping_smpl[:njoints], :]
            loss = self.metric(joints, keypoints3d)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if verbose and step % 10 == 0:
                logging.info(f'step {step:03d}; loss {loss.item():.3f};')

            if FLAGS.visualize:
                # Render the first frame: fitted SMPL mesh plus target keypoints.
                vertices = output.vertices[0].detach().cpu().numpy()
                mesh = trimesh.Trimesh(vertices, smpl.faces)
                mesh.visual.face_colors = [200, 200, 250, 100]
                pts = vedo.Points(keypoints3d[0].detach().cpu().numpy(), r=20)
                vedo.show(mesh, pts, interactive=False)

        # Return results
        return smpl, loss.item()
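Note the visualization design choice: `vedo.show(..., interactive=False)` renders the current mesh and target keypoints without blocking, so the optimization loop keeps stepping; calling it once more with `interactive=True` after the loop would hold the window open for inspection.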
Example #4
def load_dance_data(dance_dir):
    print('---------- Loading pose keypoints ----------')
    aist_dataset = AISTDataset(dance_dir)
    seq_names = list(aist_dataset.mapping_seq2env.keys())
    print(seq_names)

    dances = {}

    for seq_name in tqdm(seq_names):
        print(f'Process -> {seq_name}')
        smpl_poses, smpl_scaling, smpl_trans = AISTDataset.load_motion(
            aist_dataset.motion_dir, seq_name)
        smpl = SMPL(model_path=args.smpl_dir, gender='MALE', batch_size=1)
        keypoints3d = smpl.forward(
            global_orient=torch.from_numpy(smpl_poses[:, 0:1]).float(),
            body_pose=torch.from_numpy(smpl_poses[:, 1:]).float(),
            transl=torch.from_numpy(smpl_trans / smpl_scaling).float(),
        ).joints.detach().numpy()[:, 0:24, :]
        nframes = keypoints3d.shape[0]
        dances[seq_name] = keypoints3d.reshape(nframes, -1).tolist()
        print(np.shape(dances[seq_name]))  # (nframes, 72)

    return dances
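Each sequence is stored flattened to (nframes, 72), so a consumer has to undo the flattening. A minimal sketch, assuming the 24-joint slice used above (the annotation path is a placeholder):

import numpy as np

dances = load_dance_data('data/aist_plusplus')
for seq_name, flat in dances.items():
    flat = np.asarray(flat)           # (nframes, 72)
    joints = flat.reshape(-1, 24, 3)  # back to (nframes, 24 joints, xyz)
    print(seq_name, joints.shape)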
Example #5
def main():
    args = get_args()
    if args.aist:
        import vedo

    if not os.path.exists(args.json_dir):
        os.makedirs(args.json_dir)

    if args.aist:
        print("test with AIST++ dataset!")
        music_data, dance_data, dance_names = load_data_aist(
            args.input_dir, interval=None, rotmat=args.rotmat)
    else:
        music_data, dance_data, dance_names = load_data(
            args.input_dir, interval=None)

    device = torch.device('cuda' if args.cuda else 'cpu')

    test_loader = torch.utils.data.DataLoader(
        DanceDataset(music_data, dance_data),
        batch_size=args.batch_size,
        collate_fn=paired_collate_fn
    )

    generator = Generator(args.model, device)

    if args.aist and args.rotmat:
        from smplx import SMPL
        smpl = SMPL(model_path="/media/ruilongli/hd1/Data/smpl/", gender='MALE', batch_size=1)

    results = []
    random_id = 0  # np.random.randint(0, 1e4)
    for i, batch in enumerate(tqdm(test_loader, desc='Generating dance poses')):
        # Prepare data
        src_seq, src_pos, tgt_pose = map(lambda x: x.to(device), batch)
        pose_seq = generator.generate(src_seq[:, :1200], src_pos[:, :1200])  # first 20 secs
        results.append(pose_seq)

        if args.aist:
            np_dance = pose_seq[0].data.cpu().numpy()
            if args.rotmat:
                root = np_dance[:, :3]
                rotmat = np_dance[:, 3:].reshape([-1, 3, 3])
                rotmat = get_closest_rotmat(rotmat)
                smpl_poses = rotmat2aa(rotmat).reshape(-1, 24, 3)
                np_dance = smpl.forward(
                    global_orient=torch.from_numpy(smpl_poses[:, 0:1]).float(),
                    body_pose=torch.from_numpy(smpl_poses[:, 1:]).float(),
                    transl=torch.from_numpy(root).float(),
                ).joints.detach().numpy()[:, 0:24, :]
            else:
                root = np_dance[:, :3]
                np_dance = np_dance + np.tile(root, (1, 24))
                np_dance[:, :3] = root
                np_dance = np_dance.reshape(np_dance.shape[0], -1, 3)
            print(np_dance.shape)
            # save
            save_path = os.path.join(args.json_dir, dance_names[i]+f"_{random_id:04d}")
            np.save(save_path, np_dance)
            # visualize (only the first generated sequence, then quit)
            for frame in np_dance:
                pts = vedo.Points(frame, r=20)
                vedo.show(pts, interactive=False)
                time.sleep(0.1)
            exit()

    if not args.aist:
        # Visualize generated dance poses
        np_dances = []
        for i in range(len(results)):
            np_dance = results[i][0].data.cpu().numpy()
            root = np_dance[:, 2*8:2*9]
            np_dance = np_dance + np.tile(root, (1, 25))
            np_dance[:, 2*8:2*9] = root
            np_dances.append(np_dance)
        write2json(np_dances, dance_names, args)
        visualize(args)
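The helpers `get_closest_rotmat` and `rotmat2aa` are not shown here. The sketch below captures what they plausibly do, using SVD projection onto SO(3) and `scipy.spatial.transform`; this is an assumption about their behavior, not the original implementation:

import numpy as np
from scipy.spatial.transform import Rotation


def get_closest_rotmat(rotmats):
    """Project (..., 3, 3) matrices onto valid rotations via SVD (sketch)."""
    u, _, vt = np.linalg.svd(rotmats)
    # Flip the last singular vector where the product would be a reflection.
    det = np.linalg.det(u @ vt)
    u[det < 0, :, -1] *= -1
    return u @ vt


def rotmat2aa(rotmats):
    """Convert (..., 3, 3) rotation matrices to axis-angle vectors (sketch)."""
    flat = rotmats.reshape(-1, 3, 3)
    aa = Rotation.from_matrix(flat).as_rotvec()
    return aa.reshape(*rotmats.shape[:-2], 3)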
Example #6
def main(_):
    # Parsing data info.
    aist_dataset = AISTDataset(FLAGS.anno_dir)
    video_path = os.path.join(FLAGS.video_dir, f'{FLAGS.video_name}.mp4')
    seq_name, view = AISTDataset.get_seq_name(FLAGS.video_name)
    view_idx = AISTDataset.VIEWS.index(view)

    # Parsing keypoints.
    if FLAGS.mode == '2D':  # raw keypoint detection results.
        keypoints2d, _, _ = AISTDataset.load_keypoint2d(
            aist_dataset.keypoint2d_dir, seq_name)
        keypoints2d = keypoints2d[view_idx, :, :, 0:2]

    elif FLAGS.mode == '3D':  # 3D keypoints with temporal optimization.
        keypoints3d = AISTDataset.load_keypoint3d(aist_dataset.keypoint3d_dir,
                                                  seq_name,
                                                  use_optim=True)
        nframes, njoints, _ = keypoints3d.shape
        env_name = aist_dataset.mapping_seq2env[seq_name]
        cgroup = AISTDataset.load_camera_group(aist_dataset.camera_dir,
                                               env_name)
        keypoints2d = cgroup.project(keypoints3d)
        keypoints2d = keypoints2d.reshape(9, nframes, njoints, 2)[view_idx]

    elif FLAGS.mode == 'SMPL':  # SMPL joints
        smpl_poses, smpl_scaling, smpl_trans = AISTDataset.load_motion(
            aist_dataset.motion_dir, seq_name)
        smpl = SMPL(model_path=FLAGS.smpl_dir, gender='MALE', batch_size=1)
        keypoints3d = smpl.forward(
            global_orient=torch.from_numpy(smpl_poses[:, 0:1]).float(),
            body_pose=torch.from_numpy(smpl_poses[:, 1:]).float(),
            transl=torch.from_numpy(smpl_trans).float(),
            scaling=torch.from_numpy(smpl_scaling.reshape(1, 1)).float(),
        ).joints.detach().numpy()

        nframes, njoints, _ = keypoints3d.shape
        env_name = aist_dataset.mapping_seq2env[seq_name]
        cgroup = AISTDataset.load_camera_group(aist_dataset.camera_dir,
                                               env_name)
        keypoints2d = cgroup.project(keypoints3d)
        keypoints2d = keypoints2d.reshape(9, nframes, njoints, 2)[view_idx]

    elif FLAGS.mode == 'SMPLMesh':  # SMPL Mesh
        import trimesh  # install by `pip install trimesh`
        import vedo  # install by `pip install vedo`
        smpl_poses, smpl_scaling, smpl_trans = AISTDataset.load_motion(
            aist_dataset.motion_dir, seq_name)
        smpl = SMPL(model_path=FLAGS.smpl_dir, gender='MALE', batch_size=1)
        vertices = smpl.forward(
            global_orient=torch.from_numpy(smpl_poses[:, 0:1]).float(),
            body_pose=torch.from_numpy(smpl_poses[:, 1:]).float(),
            transl=torch.from_numpy(smpl_trans).float(),
            scaling=torch.from_numpy(smpl_scaling.reshape(1, 1)).float(),
        ).vertices.detach().numpy()[0]  # first frame
        faces = smpl.faces
        mesh = trimesh.Trimesh(vertices, faces)
        mesh.visual.face_colors = [200, 200, 250, 100]

        keypoints3d = AISTDataset.load_keypoint3d(aist_dataset.keypoint3d_dir,
                                                  seq_name,
                                                  use_optim=True)
        pts = vedo.Points(keypoints3d[0], r=20)  # first frame

        vedo.show(mesh, pts, interactive=True)
        exit()  # the mesh viewer replaces the 2D video plotting below

    # Visualize.
    os.makedirs(FLAGS.save_dir, exist_ok=True)
    save_path = os.path.join(FLAGS.save_dir, f'{FLAGS.video_name}.mp4')
    plot_on_video(keypoints2d, video_path, save_path, fps=60)
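A note on the shape bookkeeping after `cgroup.project`: the camera group projects every input point for each of the 9 AIST++ views at once, so the flat result has to be split back per view. The lines above are equivalent to this sketch, which replaces the magic number 9 with the views list:

nframes, njoints = keypoints3d.shape[0:2]
points2d = cgroup.project(keypoints3d)  # one 2D point per camera per input point
points2d = points2d.reshape(len(AISTDataset.VIEWS), nframes, njoints, 2)
keypoints2d = points2d[view_idx]        # keep only the requested camera view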