Example 1
 def __init__(
     self,
     image_size,
     model,
     fill_back=True,
     use_backward=True,
     lambda_data=1,
     lambda_consist=1,
     criterion="l1",
     consist_scale=1,
     first_only=True,
     gt_refs=True,
     progressive_consist=True,
     progressive_steps=1000,
 ):
     super().__init__()
     self.fill_back = fill_back
     self.use_backward = use_backward
     max_size = max(image_size)
     self.image_size = image_size
     self.lambda_data = lambda_data
     self.lambda_consist = lambda_consist
     self.consist_scale = consist_scale
     self.criterion = pyramidloss.PyramidCriterion(criterion)
     self.first_only = first_only
     self.progressive_consist = progressive_consist
     self.progressive_steps = progressive_steps
     self.gt_refs = gt_refs
     self.step_count = 0
     neurenderer = renderer.Renderer(
         image_size=max_size,
         R=torch.eye(3).unsqueeze(0).cuda(),
         t=torch.zeros(1, 3).cuda(),
         K=torch.ones(1, 3, 3).cuda(),
         orig_size=max_size,
         anti_aliasing=False,
         fill_back=fill_back,
         near=0.1,
         no_light=True,
         light_intensity_ambient=0.8,
     )
     self.renderer = neurenderer
     self.model = model
     self.mano_layer = manolayer.ManoLayer(
         joint_rot_mode="axisang",
         use_pca=False,
         mano_root="assets/mano",
         center_idx=None,
         flat_hand_mean=True,
     )
     closed_faces, hand_ignore_faces = manoutils.get_closed_faces()
     self.hand_ignore_faces = hand_ignore_faces
     self.mano_layer.register_buffer("th_faces", closed_faces)
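
A minimal instantiation sketch: the wrapping class is not named in this excerpt, so `ConsistNet` and `backbone` below are placeholders, and a CUDA device is assumed because the renderer buffers are created with `.cuda()`.

# Hypothetical usage; `ConsistNet` and `backbone` are placeholders, not the project's API.
net = ConsistNet(
    image_size=(270, 480),  # training image size; only max(image_size) feeds the renderer
    model=backbone,         # the prediction network being wrapped
    lambda_data=1,
    lambda_consist=1,
)
net.cuda()
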
Example 2
def main(args):
    setseeds.set_all_seeds(args.manual_seed)
    # Initialize hosting
    exp_id = f"checkpoints/{args.dataset}/{args.com}"

    # Initialize local checkpoint folder
    print(f"Saving info about experiment at {exp_id}")
    save_args(args, exp_id, "opt")
    render_folder = os.path.join(exp_id, "images")
    os.makedirs(render_folder, exist_ok=True)
    # Load models
    models = []
    for resume in args.resumes:
        print('Resuming', resume)
        opts = reloadmodel.load_opts(resume)
        model, epoch = reloadmodel.reload_model(resume, opts)
        models.append(model)
        freeze.freeze_batchnorm_stats(model)  # Freeze batchnorm

    dataset, input_res = get_dataset.get_dataset(
        args.dataset,
        split=args.split,
        meta={"version": args.version, "split_mode": args.split_mode},
        mode=args.mode,
        use_cache=args.use_cache,
        no_augm=True,
        center_idx=opts["center_idx"],
        sample_nb=None,
    )

    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=int(args.workers),
        drop_last=False,
        collate_fn=collate.meshreg_collate,
    )


    all_samples = []

    # Put models on GPU and evaluation mode
    for model in models:
        model.cuda()
        model.eval()

    i = 0
    for batch in tqdm(loader):  # Loop over batches
        all_results = []
        # Compute model outputs
        with torch.no_grad():
            for model in models:
                _, results, _ = model(batch)
                all_results.append(results)

        # Densely render error map for the meshes
        # for results in all_results:
        #     render_results, cmap_obj = fastrender.comp_render(
        #         batch, all_results, rotate=True, modes=("all", "obj", "hand"), max_val=args.max_val
        #     )
        
        # if i > 100:
        #     break
        # i += 1

        for img_idx, img in enumerate(batch[BaseQueries.IMAGE]):  # Loop over images in the batch
            network_out = all_results[0]
            sample_idx = batch['idx'][img_idx]
            handpose, handtrans, handshape = dataset.pose_dataset.get_hand_info(sample_idx)

            sample_dict = dict()
            # sample_dict['image'] = img
            sample_dict['obj_faces'] = batch[BaseQueries.OBJFACES][img_idx, :, :]
            sample_dict['obj_verts_gt'] = batch[BaseQueries.OBJVERTS3D][img_idx, :, :]
            sample_dict['hand_faces'], _ = manoutils.get_closed_faces()
            sample_dict['hand_verts_gt'] = batch[BaseQueries.HANDVERTS3D][img_idx, :, :]

            sample_dict['obj_verts_pred'] = network_out['recov_objverts3d'][img_idx, :, :]
            sample_dict['hand_verts_pred'] = network_out['recov_handverts3d'][img_idx, :, :]
            # sample_dict['hand_adapt_trans'] = network_out['mano_adapt_trans'][img_idx, :]
            sample_dict['hand_pose_pred'] = network_out['pose'][img_idx, :]
            sample_dict['hand_beta_pred'] = network_out['shape'][img_idx, :]
            sample_dict['side'] = batch[BaseQueries.SIDE][img_idx]

            sample_dict['hand_pose_gt'] = handpose
            sample_dict['hand_beta_gt'] = handshape
            sample_dict['hand_trans_gt'] = handtrans
            sample_dict['hand_extr_gt'] = torch.Tensor([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])

            for k in sample_dict.keys():
                sample_dict[k] = to_cpu_npy(sample_dict[k])

            all_samples.append(sample_dict)

            # The `continue` below skips the debug-only shape printing and 3D
            # plotting that follows; remove it to run that visualization.
            continue

            # obj_verts_gt = samplevis.get_check_none(sample, BaseQueries.OBJVERTS3D, cpu=False)
            # hand_verts_gt = samplevis.get_check_none(sample, BaseQueries.HANDVERTS3D, cpu=False)
            # hand_verts = samplevis.get_check_none(results, "recov_handverts3d", cpu=False)
            # obj_verts = samplevis.get_check_none(results, "recov_objverts3d", cpu=False)
            # obj_faces = samplevis.get_check_none(sample, BaseQueries.OBJFACES, cpu=False).long()
            # hand_faces, _ = manoutils.get_closed_faces()


            # plt.imshow(img)
            # plt.show()

            for k in batch.keys():
                elem = batch[k]
                # if isinstance(elem, list):
                #     s = len(s)
                if isinstance(elem, torch.Tensor):
                    s = elem.shape
                else:
                    s = elem
                print('{}: Shape {}'.format(k, s))

            for k in all_results[0].keys():
                elem = all_results[0][k]
                if isinstance(elem, list):
                    s = len(elem)
                elif isinstance(elem, torch.Tensor):
                    s = elem.shape
                else:
                    s = elem
                print('Network out {}: Shape {}'.format(k, s))

            fig = plt.figure()
            ax = fig.add_subplot(projection='3d')
            ax.plot_trisurf(sample_dict['obj_verts_gt'][:,0], sample_dict['obj_verts_gt'][:,1], sample_dict['obj_verts_gt'][:,2],
                            triangles=sample_dict['obj_faces'])

            ax.plot_trisurf(sample_dict['hand_verts_gt'][:,0], sample_dict['hand_verts_gt'][:,1], sample_dict['hand_verts_gt'][:,2],
                            triangles=sample_dict['hand_faces'])

            ax.plot_trisurf(sample_dict['obj_verts_pred'][:,0], sample_dict['obj_verts_pred'][:,1], sample_dict['obj_verts_pred'][:,2],
                            triangles=sample_dict['obj_faces'])

            ax.plot_trisurf(sample_dict['hand_verts_pred'][:,0], sample_dict['hand_verts_pred'][:,1], sample_dict['hand_verts_pred'][:,2],
                            triangles=sample_dict['hand_faces'])

            plt.show()

    print('Saving final dict', len(all_samples))
    with open('all_samples.pkl', 'wb') as handle:
        pickle.dump(all_samples, handle)
    print('Done saving')
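
The `to_cpu_npy` helper used above is not shown in this excerpt; a minimal sketch of what it could look like (an assumption, not necessarily the project's actual implementation):

import torch

def to_cpu_npy(x):
    # Detach tensors and move them to CPU numpy arrays; pass other values through unchanged.
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    return x
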
Example 3
def main(args):
    torch.cuda.manual_seed_all(args.manual_seed)
    torch.manual_seed(args.manual_seed)
    np.random.seed(args.manual_seed)
    random.seed(args.manual_seed)
    # Initialize hosting
    dat_str = args.val_dataset
    now = datetime.now()
    exp_id = (
        f"checkpoints/{dat_str}_mini{args.mini_factor}/"
        f"{now.year}_{now.month:02d}_{now.day:02d}/"
        f"{args.com}_frac{args.fraction}_mode{args.mode}_bs{args.batch_size}_"
        f"objs{args.obj_scale_factor}_objt{args.obj_trans_factor}")

    # Initialize local checkpoint folder
    save_args(args, exp_id, "opt")
    result_folder = os.path.join(exp_id, "results")
    os.makedirs(result_folder, exist_ok=True)
    pyapt_path = os.path.join(result_folder,
                              f"{args.pyapt_id}__{now.strftime('%H_%M_%S')}")
    with open(pyapt_path, "a") as t_f:
        t_f.write(" ")

    val_dataset, input_size = get_dataset.get_dataset(
        args.val_dataset,
        split=args.val_split,
        meta={
            "version": args.version,
            "split_mode": args.split_mode
        },
        use_cache=args.use_cache,
        mini_factor=args.mini_factor,
        mode=args.mode,
        fraction=args.fraction,
        no_augm=True,
        center_idx=args.center_idx,
        scale_jittering=0,
        center_jittering=0,
        sample_nb=None,
        has_dist2strong=True,
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=int(args.workers),
        drop_last=False,
        collate_fn=collate.meshreg_collate,
    )

    opts = reloadmodel.load_opts(args.resume)
    model, epoch = reloadmodel.reload_model(args.resume, opts)
    freeze.freeze_batchnorm_stats(model)  # Freeze batchnorm

    all_samples = []

    model.eval()
    model.cuda()
    for batch_idx, batch in enumerate(tqdm(val_loader)):
        all_results = []
        with torch.no_grad():
            loss, results, losses = model(batch)
            all_results.append(results)

        for img_idx, img in enumerate(
                batch[BaseQueries.IMAGE]):  # Loop over images in the batch
            network_out = all_results[0]
            sample_idx = batch['idx'][img_idx]
            handpose, handtrans, handshape = val_dataset.pose_dataset.get_hand_info(
                sample_idx)

            sample_dict = dict()
            # sample_dict['image'] = img
            sample_dict['obj_faces'] = batch[BaseQueries.OBJFACES][
                img_idx, :, :]
            sample_dict['obj_verts_gt'] = batch[BaseQueries.OBJVERTS3D][
                img_idx, :, :]
            sample_dict['hand_faces'], _ = manoutils.get_closed_faces()
            sample_dict['hand_verts_gt'] = batch[BaseQueries.HANDVERTS3D][
                img_idx, :, :]

            sample_dict['obj_verts_pred'] = network_out['recov_objverts3d'][
                img_idx, :, :]
            sample_dict['hand_verts_pred'] = network_out['recov_handverts3d'][
                img_idx, :, :]
            # sample_dict['hand_adapt_trans'] = network_out['mano_adapt_trans'][img_idx, :]
            sample_dict['hand_pose_pred'] = network_out['pose'][img_idx, :]
            sample_dict['hand_beta_pred'] = network_out['shape'][img_idx, :]
            sample_dict['side'] = batch[BaseQueries.SIDE][img_idx]

            sample_dict['hand_pose_gt'] = handpose
            sample_dict['hand_beta_gt'] = handshape
            sample_dict['hand_trans_gt'] = handtrans
            sample_dict['hand_extr_gt'] = torch.Tensor([[1, 0, 0, 0],
                                                        [0, -1, 0, 0],
                                                        [0, 0, -1, 0],
                                                        [0, 0, 0, 1]])

            for k in sample_dict.keys():
                sample_dict[k] = to_cpu_npy(sample_dict[k])

            all_samples.append(sample_dict)

    print('Saving final dict', len(all_samples))
    with open('all_samples.pkl', 'wb') as handle:
        pickle.dump(all_samples, handle)
    print('Done saving')
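
A short sketch of reading the dump back; the file name and keys come from the script above, nothing else is assumed:

import pickle

with open('all_samples.pkl', 'rb') as handle:
    all_samples = pickle.load(handle)
print(len(all_samples), 'samples, keys:', sorted(all_samples[0].keys()))
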
Example 4
def comp_render(
        sample,
        all_results,
        fill_back=True,
        near=0.05,
        far=2,
        modes=("all"),
        rotate=True,
        crop_to_img=False,
        max_val=0.1,
):
    images = sample[TransQueries.IMAGE].permute(0, 2, 3, 1).cpu() + 0.5
    camintrs = sample[TransQueries.CAMINTR].cuda()
    hand_faces, _ = manoutils.get_closed_faces()
    batch_size = images.shape[0]
    hand_faces_b = hand_faces.unsqueeze(0).repeat(batch_size, 1,
                                                  1).long().cuda()
    all_hand_verts = []
    all_obj_verts = []
    all_obj_faces = []
    for results in all_results:
        hand_verts = samplevis.get_check_none(results,
                                              "recov_handverts3d",
                                              cpu=False)
        obj_verts = samplevis.get_check_none(results,
                                             "recov_objverts3d",
                                             cpu=False)
        obj_faces = samplevis.get_check_none(sample,
                                             BaseQueries.OBJFACES,
                                             cpu=False).long()
        all_hand_verts.append(hand_verts)
        all_obj_verts.append(obj_verts)
        all_obj_faces.append(obj_faces)

    obj_verts_gt = samplevis.get_check_none(sample,
                                            BaseQueries.OBJVERTS3D,
                                            cpu=False)
    hand_verts_gt = samplevis.get_check_none(sample,
                                             BaseQueries.HANDVERTS3D,
                                             cpu=False)
    hand_colors_gt = consistdisplay.get_verts_colors(hand_verts_gt,
                                                     [0.3, 0.3, 1])
    obj_colors_gt = consistdisplay.get_verts_colors(obj_verts_gt,
                                                    [1, 0.3, 0.3])
    all_colors_gt = torch.cat([hand_colors_gt, obj_colors_gt], 1)
    # Render
    jet_cmap = cm.get_cmap("jet")
    all_obj_errs = torch.stack(
        [pred_vert - obj_verts_gt for pred_vert in all_obj_verts]).norm(2, -1)
    all_hand_errs = torch.stack([
        pred_vert - hand_verts_gt for pred_vert in all_hand_verts
    ]).norm(2, -1)

    all_obj_colors = all_obj_errs / max_val
    all_hand_colors = all_hand_errs / max_val
    cmap_objs = jet_cmap(all_obj_colors.cpu().numpy())
    cmap_hands = jet_cmap(all_hand_colors.cpu().numpy())
    obj_colors = obj_verts_gt.new(cmap_objs)
    hand_colors = obj_verts_gt.new(cmap_hands)
    all_colors = torch.cat([hand_colors, obj_colors], 2)
    input_res = (images.shape[2], images.shape[1])
    all_render_res = defaultdict(list)
    all_verts_gt, all_faces, _ = catmesh.batch_cat_meshes(
        [hand_verts_gt, obj_verts_gt], [hand_faces_b, obj_faces])
    # Render ground truth meshes
    renderot_res = None  # only set when the combined mesh is rendered with rotation
    for mode in modes:
        if mode == "all":
            render_res = render(
                all_verts_gt,
                all_faces,
                input_res,
                camintrs=camintrs,
                colors=all_colors_gt,
                near=near,
                far=far,
                fill_back=fill_back,
                crop_to_img=crop_to_img,
            )
            if rotate:
                rotated_verts = rotateverts.rotate_verts(all_verts_gt)
                renderot_res = render(
                    rotated_verts,
                    all_faces,
                    input_res,
                    camintrs=camintrs,
                    colors=all_colors_gt,
                    near=near,
                    far=far,
                    fill_back=fill_back,
                    crop_to_img=crop_to_img,
                )
        elif mode == "hand":
            render_res = render(
                hand_verts_gt,
                hand_faces_b,
                input_res,
                camintrs=camintrs,
                colors=hand_colors_gt,
                near=near,
                far=far,
                fill_back=fill_back,
                crop_to_img=crop_to_img,
            )
        elif mode == "obj":
            render_res = render(
                obj_verts_gt,
                obj_faces,
                input_res,
                camintrs=camintrs,
                colors=obj_colors_gt,
                near=near,
                far=far,
                fill_back=fill_back,
                crop_to_img=crop_to_img,
            )
        all_render_res[mode].append(render_res.cpu())
        if rotate and renderot_res is not None:
            all_render_res[f"{mode}_rotated"].append(renderot_res.cpu())
    # Render predictions
    for model_idx, (hand_verts, obj_verts, obj_faces) in enumerate(
            zip(all_hand_verts, all_obj_verts, all_obj_faces)):
        all_verts, all_faces, _ = catmesh.batch_cat_meshes(
            [hand_verts, obj_verts], [hand_faces_b, obj_faces])
        renderot_res = None
        for mode in modes:
            if mode == "all":
                render_res = render(
                    all_verts,
                    all_faces,
                    input_res,
                    camintrs=camintrs,
                    colors=all_colors[model_idx],
                    near=near,
                    far=far,
                    fill_back=fill_back,
                    crop_to_img=crop_to_img,
                )
                if rotate:
                    rotated_verts = rotateverts.rotate_verts(all_verts)
                    renderot_res = render(
                        rotated_verts,
                        all_faces,
                        input_res,
                        camintrs=camintrs,
                        colors=all_colors[model_idx],
                        near=near,
                        far=far,
                        fill_back=fill_back,
                        crop_to_img=crop_to_img,
                    )
            elif mode == "obj":
                render_res = render(
                    obj_verts,
                    obj_faces,
                    input_res,
                    camintrs=camintrs,
                    colors=obj_colors[model_idx],
                    near=near,
                    far=far,
                    fill_back=fill_back,
                    crop_to_img=crop_to_img,
                )
            elif mode == "hand":
                render_res = render(
                    hand_verts,
                    hand_faces_b,
                    input_res,
                    camintrs=camintrs,
                    # colors=hand_colors[model_idx],
                    colors=hand_colors_gt,
                    near=near,
                    far=far,
                    fill_back=fill_back,
                    crop_to_img=crop_to_img,
                )
            all_render_res[f"{mode}"].append(render_res.cpu())
            all_render_res[f"{mode}_rotated"].append(renderot_res.cpu())
    return all_render_res, cmap_objs
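
Usage mirrors the call commented out in Example 2; the `fastrender` module name comes from that snippet, and `batch` / `all_results` are assumed to be a collated batch and a list of per-model result dicts as in that script:

# Sketch based on the commented-out call in Example 2.
render_results, cmap_obj = fastrender.comp_render(
    batch, all_results, rotate=True, modes=("all", "obj", "hand"), max_val=0.1
)
gt_and_pred_renders = render_results["all"]      # ground-truth render first, then one per model
rotated_renders = render_results["all_rotated"]  # rotated views of the combined meshes
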
Example 5
def hand_obj_render(
        sample,
        results,
        hand_colors=(0.3, 0.3, 1),
        obj_colors=(1, 0.3, 0.3),
        fill_back=True,
        near=0.05,
        far=2,
        modes=("all"),
        rotate=True,
        crop_to_img=True,
):
    images = sample[TransQueries.IMAGE].permute(0, 2, 3, 1).cpu() + 0.5
    camintrs = sample[TransQueries.CAMINTR].cuda()
    batch_size = images.shape[0]
    input_res = (images.shape[2], images.shape[1])  # (width, height), needed by all render calls below
    hand_verts = samplevis.get_check_none(results,
                                          "recov_handverts3d",
                                          cpu=False)
    obj_verts = samplevis.get_check_none(results,
                                         "recov_objverts3d",
                                         cpu=False)
    obj_faces = samplevis.get_check_none(sample,
                                         BaseQueries.OBJFACES,
                                         cpu=False)

    # Render
    render_results = {}
    if hand_verts is not None:
        # Initialize faces and textures, TODO use closed hand compatible with model vertices
        hand_faces, _ = manoutils.get_closed_faces()
        hand_faces_b = hand_faces.unsqueeze(0).repeat(batch_size, 1,
                                                      1).long().cuda()
        hand_colors = consistdisplay.get_verts_colors(hand_verts,
                                                      [0.3, 0.3, 1])
        if "hand" in modes:
            render_res = render(
                hand_verts,
                hand_faces_b,
                input_res,
                camintrs=camintrs,
                colors=hand_colors,
                near=near,
                far=far,
                fill_back=fill_back,
                crop_to_img=crop_to_img,
            )
            render_results["hand"] = render_res
    if obj_verts is not None:
        obj_colors = consistdisplay.get_verts_colors(obj_verts, [1, 0.3, 0.3])
        obj_faces = obj_faces.long()
        if "obj" in modes:
            render_res = render(
                obj_verts,
                obj_faces,
                input_res,
                camintrs=camintrs,
                colors=obj_colors,
                near=near,
                far=far,
                fill_back=fill_back,
                crop_to_img=crop_to_img,
            )
            render_results["obj"] = render_res
    if obj_verts is not None and hand_verts is not None:
        colors = torch.cat([hand_colors, obj_colors], 1)
        all_verts, all_faces, _ = catmesh.batch_cat_meshes(
            [hand_verts, obj_verts], [hand_faces_b, obj_faces])
        if "all" in modes:
            render_res = render(
                all_verts,
                all_faces,
                input_res,
                camintrs=camintrs,
                colors=colors,
                near=near,
                far=far,
                fill_back=fill_back,
                crop_to_img=crop_to_img,
            )
            render_results["all"] = render_res
            if rotate:
                rotated_verts = rotateverts.rotate_verts(all_verts)
                render_res = render(
                    rotated_verts,
                    all_faces,
                    input_res,
                    camintrs=camintrs,
                    colors=colors,
                    near=near,
                    far=far,
                    fill_back=fill_back,
                    crop_to_img=crop_to_img,
                )
                render_results["all_rotated"] = render_res
    return dict(render_results)
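
A minimal call sketch; `batch` and `results` are assumed to be a collated batch and a single model's output dict, as in the earlier examples:

# Hypothetical usage under the assumptions above.
render_results = hand_obj_render(batch, results, modes=("all", "hand", "obj"), rotate=True)
hand_render = render_results.get("hand")  # per-image hand renders, if hand vertices were predicted
full_render = render_results.get("all")   # combined hand + object render ("all_rotated" is also set when rotate=True)
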