예제 #1
0
파일: fit_SMPLX.py 프로젝트: MoyGcc/IPNet
def forward_step(th_scan_meshes, smpl, th_pose_3d=None):
    """
    Run the SMPL model once and evaluate the fitting losses against the scans.

    Args:
        th_scan_meshes: sequence of scan meshes, one per batch element; only
            their ``vertices`` attribute is read.
        smpl: SMPL layer. Calling it must return a 4-tuple whose first item
            is iterated as the batch of vertex tensors. Its ``gender``,
            ``faces``, ``betas`` and ``pose`` attributes are also read.
        th_pose_3d: optional 3D pose target. When it is not None, an extra
            ``'pose_obj'`` term is computed via ``batch_get_pose_obj``.

    Returns:
        dict of loss terms keyed by name: ``'s2m'`` (scan-to-model),
        ``'m2s'`` (model-to-scan), ``'betas'`` (mean squared shape
        coefficients along axis 1), ``'pose_pr'`` (pose prior) and, if
        requested, ``'pose_obj'``.
    """
    pose_prior = get_prior(smpl.gender)

    # Only the posed vertices are needed from the SMPL forward pass.
    posed_verts, _, _, _ = smpl()
    model_meshes = [tm.from_tensors(vertices=vs, faces=smpl.faces)
                    for vs in posed_verts]

    scan_points = [mesh.vertices for mesh in th_scan_meshes]
    model_points = [mesh.vertices for mesh in model_meshes]

    # Distances are measured in both directions so that neither surface can
    # ignore regions of the other.
    losses = {
        's2m': batch_point_to_surface(scan_points, model_meshes),
        'm2s': batch_point_to_surface(model_points, th_scan_meshes),
        'betas': torch.mean(smpl.betas ** 2, axis=1),
        'pose_pr': pose_prior(smpl.pose),
    }
    if th_pose_3d is not None:
        losses['pose_obj'] = batch_get_pose_obj(th_pose_3d, smpl)
    return losses
예제 #2
0
def forward_step(th_scan_meshes,
                 smplx,
                 scan_part_labels,
                 smplx_part_labels,
                 search_tree=None,
                 pen_distance=None,
                 tri_filtering_module=None):
    """
    Performs a forward step, given smplx and scan meshes, then computes the losses.

    Args:
        th_scan_meshes: sequence of scan meshes, one per batch element; only
            their ``vertices`` attribute is read.
        smplx: SMPL-X layer. ``smplx()`` is iterated directly as the batch of
            vertex tensors (no tuple unpacking). ``smplx.faces`` and
            ``smplx.betas`` are also read.
        scan_part_labels: one label tensor per scan, indexed in step with that
            scan's vertices (assumed one label per vertex -- TODO confirm).
        smplx_part_labels: one label tensor per batch element, indexed in step
            with the SMPL-X vertices (assumed one label per vertex -- TODO
            confirm).
        search_tree, pen_distance, tri_filtering_module: forwarded unchanged
            to ``interpenetration_loss``.

    Returns:
        dict with keys ``'s2m'``, ``'m2s'``, ``'betas'``,
        ``'interpenetration'`` and ``'part'``. ``'part'`` stacks one value per
        scan: the sum of per-part chamfer distances divided by ``NUM_PARTS``.
    """
    # NOTE: the pose-prior term ('pose_pr') is disabled for SMPL-X, so the
    # prior is intentionally not loaded here on every step. Re-add
    # ``prior = get_prior(smplx.gender, precomputed=True)`` and
    # ``loss['pose_pr'] = prior(smplx.pose)`` if the term is re-enabled.

    # forward
    verts = smplx()
    th_smplx_meshes = [
        tm.from_tensors(vertices=v, faces=smplx.faces) for v in verts
    ]

    scan_verts = [sm.vertices for sm in th_scan_meshes]
    smplx_verts = [sm.vertices for sm in th_smplx_meshes]

    # losses
    loss = dict()
    loss['s2m'] = batch_point_to_surface(scan_verts, th_smplx_meshes)
    loss['m2s'] = batch_point_to_surface(smplx_verts, th_scan_meshes)
    loss['betas'] = torch.mean(smplx.betas**2, axis=1)
    # The trailing 1.0 is passed positionally; presumably a weight -- TODO confirm.
    loss['interpenetration'] = interpenetration_loss(verts, smplx.faces,
                                                     search_tree, pen_distance,
                                                     tri_filtering_module, 1.0)

    part_losses = []
    for n, (sc_v, sc_l) in enumerate(zip(scan_verts, scan_part_labels)):
        # Seed with a 0-d tensor rather than the int 0. If no label in
        # range(NUM_PARTS) occurs in this scan, torch.stack below would
        # otherwise receive a Python int and raise a TypeError.
        tot = sc_v.new_zeros(())
        sm_l = smplx_part_labels[n]
        for i in range(NUM_PARTS):
            if i not in sc_l:  # part absent from this scan: nothing to match
                continue
            sc_part_points = sc_v[torch.where(sc_l == i)[0]].unsqueeze(0)
            sm_part_points = smplx_verts[n][torch.where(sm_l == i)[0]].unsqueeze(0)
            # Out-of-place accumulation keeps the autograd graph simple.
            tot = tot + chamfer_distance(sc_part_points,
                                         sm_part_points,
                                         w1=1.,
                                         w2=1.)
        # Normalised by the total part count, not by the number of parts present.
        part_losses.append(tot / NUM_PARTS)
    loss['part'] = torch.stack(part_losses)
    return loss
예제 #3
0
def forward_step_SMPLD(th_scan_meshes, smpl, init_smpl_meshes, args):
    """
    Run the SMPL+D model once and evaluate the fitting losses against the scans.

    Args:
        th_scan_meshes: sequence of scan meshes, one per batch element; only
            their ``vertices`` attribute is read.
        smpl: SMPL+D layer. Calling it must return a 4-tuple whose first item
            is iterated as the batch of vertex tensors. Its ``faces`` and
            ``offsets`` attributes are also read.
        init_smpl_meshes: reference meshes, paired element-wise with the
            current model meshes for the Laplacian term.
        args: accepted for interface compatibility; not read here.

    Returns:
        dict of loss terms keyed by name: ``'s2m'``, ``'m2s'``, ``'lap'``
        (stacked per-sample Laplacian losses) and ``'offsets'`` (squared
        offsets averaged over axis 1 and then over the next axis).
    """
    posed_verts, _, _, _ = smpl()
    model_meshes = [tm.from_tensors(vertices=vs, faces=smpl.faces)
                    for vs in posed_verts]

    scan_points = [mesh.vertices for mesh in th_scan_meshes]
    model_points = [mesh.vertices for mesh in model_meshes]
    mesh_pairs = zip(init_smpl_meshes, model_meshes)

    return {
        's2m': batch_point_to_surface(scan_points, model_meshes),
        'm2s': batch_point_to_surface(model_points, th_scan_meshes),
        # Keep the deformed mesh close to its initialisation.
        'lap': torch.stack([laplacian_loss(init_mesh, cur_mesh)
                            for init_mesh, cur_mesh in mesh_pairs]),
        # Regularise the free-form offsets towards zero.
        'offsets': torch.mean(smpl.offsets ** 2, dim=1).mean(dim=1),
    }
예제 #4
0
def forward_step_SMPL(th_scan_meshes, smpl, scan_part_labels, smpl_part_labels, args):
    """
    Performs a forward step, given smpl and scan meshes, then computes the losses.

    Args:
        th_scan_meshes: sequence of scan meshes, one per batch element; only
            their ``vertices`` attribute is read.
        smpl: SMPL layer. Calling it must return a 4-tuple whose first item is
            iterated as the batch of vertex tensors. Its ``gender``,
            ``faces``, ``betas`` and ``pose`` attributes are also read.
        scan_part_labels: one label tensor per scan, indexed in step with that
            scan's vertices (assumed one label per vertex -- TODO confirm).
        smpl_part_labels: one label tensor per batch element, indexed in step
            with the SMPL vertices (assumed one label per vertex -- TODO
            confirm).
        args: options object; only ``args.use_parts`` is read. It toggles the
            part-wise chamfer term.

    Returns:
        dict with keys ``'s2m'``, ``'m2s'``, ``'betas'`` and ``'pose_pr'``,
        plus ``'part'`` when ``args.use_parts`` is truthy. ``'part'`` stacks
        one value per scan: the sum of per-part chamfer distances divided by
        the part count.
    """
    # Number of body parts used for the part-wise term. It was previously
    # intended to come from ``args.num_joints`` but is currently fixed.
    num_parts = 14

    # Get pose prior
    prior = get_prior(smpl.gender, precomputed=True)

    # forward
    verts, _, _, _ = smpl()
    th_smpl_meshes = [tm.from_tensors(vertices=v,
                                      faces=smpl.faces) for v in verts]

    scan_verts = [sm.vertices for sm in th_scan_meshes]
    smpl_verts = [sm.vertices for sm in th_smpl_meshes]

    # losses
    loss = dict()
    loss['s2m'] = batch_point_to_surface(scan_verts, th_smpl_meshes)
    loss['m2s'] = batch_point_to_surface(smpl_verts, th_scan_meshes)
    loss['betas'] = torch.mean(smpl.betas ** 2, axis=1)
    loss['pose_pr'] = prior(smpl.pose)

    if args.use_parts:
        part_losses = []
        for n, (sc_v, sc_l) in enumerate(zip(scan_verts, scan_part_labels)):
            # Seed with a 0-d tensor rather than the int 0. If no label in
            # range(num_parts) occurs in this scan, torch.stack below would
            # otherwise receive a Python int and raise a TypeError.
            tot = sc_v.new_zeros(())
            sm_l = smpl_part_labels[n]
            for i in range(num_parts):
                if i not in sc_l:  # part absent from this scan: nothing to match
                    continue
                sc_part_points = sc_v[torch.where(sc_l == i)[0]].unsqueeze(0)
                sm_part_points = smpl_verts[n][torch.where(sm_l == i)[0]].unsqueeze(0)
                # Out-of-place accumulation keeps the autograd graph simple.
                tot = tot + chamfer_distance(sc_part_points, sm_part_points, w1=1., w2=1.)
            # Normalised by the total part count, not by the number of parts present.
            part_losses.append(tot / num_parts)

        loss['part'] = torch.stack(part_losses)

    return loss