def handPoseMF(w, h, objParamInitList, handParamInitList, objMesh, camProp, out_dir):
    ds = tf.data.Dataset.from_generator(lambda: dataGen(w, h),
                                        (tf.string, tf.float32, tf.float32, tf.float32, tf.float32),
                                        ((None,), (None, h, w, 3), (None, h, w, 3), (None, h, w, 3), (None, h, w, 3)))
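    # dataGen (defined elsewhere in the repo) yields tuples that match the
    # output types/shapes declared above: a batch of frame-ID strings plus four
    # HxWx3 float observation maps, which LossObservs later unpacks into the
    # color/segmentation/depth observables used below.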
    assert len(objParamInitList)==len(handParamInitList)

    numFrames = len(objParamInitList)

    # read real observations
    frameCntInt, loadData, realObservs = LossObservs.getRealObservables(ds, numFrames, w, h)
    icp = Icp(realObservs, camProp)

    # set up the scene
    scene = Scene(optModeEnum.MULTIFRAME_RIGID_HO_POSE, frameCnt=numFrames)
    objID = scene.addObject(objMesh, objParamInitList, segColor=objSegColor)
    handID = scene.addHand(handParamInitList, handSegColor, baseItemID=objID)
    scene.addCamera(f=camProp.f, c=camProp.c, near=camProp.near, far=camProp.far, frameSize=camProp.frameSize)
    finalMesh = scene.getFinalMesh()

    # render the scene
    renderer = DirtRenderer(finalMesh, renderModeEnum.SEG_COLOR_DEPTH)
    virtObservs = renderer.render()

    # get loss over observables
    observLoss = LossObservs(virtObservs, realObservs, renderModeEnum.SEG_COLOR_DEPTH)
    segLoss, depthLoss, colLoss = observLoss.getL2Loss(isClipDepthLoss=True, pyrLevel=2)
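    # the L2 losses are evaluated over an image pyramid (pyrLevel=2), and the
    # depth residual is clipped (isClipDepthLoss=True) to limit the influence
    # of large depth outliers; see LossObservs for the exact definitions.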

    # get parameters and constraints
    handConstrs = Constraints()
    paramListHand = scene.getVarsByItemID(handID, [varTypeEnum.HAND_JOINT, varTypeEnum.HAND_ROT])
    jointAngs = paramListHand[0]
    handRot = paramListHand[1]
    validTheta = tf.concat([handRot, jointAngs], axis=0)
    theta = handConstrs.getFullThetafromValidTheta(validTheta)
    thetaConstrs, _ = handConstrs.getHandThetaConstraints(validTheta, isValidTheta=True)
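    # validTheta stacks the global rotation (3 DOFs) with the optimizable joint
    # angles; getFullThetafromValidTheta expands it back to the full 48-dim
    # MANO pose, and thetaConstrs penalizes angles outside their valid ranges
    # (see Constraints for the exact limits).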

    # some variables for vis and analysis
    paramListObj = scene.getParamsByItemID([parTypeEnum.OBJ_ROT, parTypeEnum.OBJ_TRANS, parTypeEnum.OBJ_POSE_MAT], objID)
    rotObj = paramListObj[0]
    transObj = paramListObj[1]
    poseMat = paramListObj[2]

    paramListHand = scene.getParamsByItemID([parTypeEnum.HAND_THETA, parTypeEnum.HAND_TRANS, parTypeEnum.HAND_BETA],
                                           handID)
    thetaMat = paramListHand[0]
    transHand = paramListHand[1]
    betaHand = paramListHand[2]


    # contact loss
    hoContact = ContactLoss(scene.itemPropsDict[objID].transformedMesh, scene.itemPropsDict[handID].transformedMesh,
                            scene.itemPropsDict[handID].transorfmedJs)
    contLoss = hoContact.getRepulsionLoss()

    # get icp losses
    icpLossHand = icp.getLoss(scene.itemPropsDict[handID].transformedMesh.v, handSegColor)
    icpLossObj = icp.getLoss(scene.itemPropsDict[objID].transformedMesh.v, objSegColor)

    # get rel hand obj pose loss
    handTransVars = tf.stack(scene.getVarsByItemID(handID, [varTypeEnum.HAND_TRANS_REL_DELTA]), axis=0)
    handRotVars = tf.stack(scene.getVarsByItemID(handID, [varTypeEnum.HAND_ROT_REL_DELTA]), axis=0)
    relPoseLoss = handConstrs.getHandObjRelDeltaPoseConstraint(handRotVars, handTransVars)
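    # relPoseLoss penalizes the per-frame delta of the hand pose relative to
    # the object, keeping the grasp (nearly) rigid across all frames.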

    if use2DJointLoss:
        # get 2d joint loss
        transJs = tf.reshape(scene.itemPropsDict[handID].transorfmedJs, [-1, 3])
        projJs = tfProjectPoints(camProp, transJs)
        projJs = tf.reshape(projJs, [numFrames, 21, 2])
        if handParamInitList[0].JTransformed[0, 2] < 0:
            isOpenGLCoords = True
        else:
            isOpenGLCoords = False
        joints2DGT = np.stack(
            [cv2ProjectPoints(camProp, hpi.JTransformed, isOpenGLCoords) for hpi in handParamInitList], axis=0)
        jointVisGT = np.stack([hpi.JVis for hpi in handParamInitList], axis=0)
        joints2DErr = tf.reshape(tf.reduce_sum(tf.square(projJs - joints2DGT), axis=2), [-1])
        joints2DErr = joints2DErr * tf.reshape(jointVisGT, [-1])  # zero-out joints that are not visible in the GT
        joints2DLoss = tf.reduce_sum(joints2DErr)
        wrist2DErr = tf.reshape(tf.reduce_sum(tf.square(projJs[:,0,:] - joints2DGT[:,0,:]), axis=1), [-1])
        wrist2DErr = wrist2DErr * tf.reshape(jointVisGT[:,0], [-1])
        wrist2DLoss = tf.reduce_sum(wrist2DErr)

    # get final loss
    # loss term weights (tuned empirically)
    icpWt = 5e3
    j2dWt = 0.      # 2D joint loss disabled
    segWt = 5e1
    depWt = 1e1
    wristJWt = 1e-3
    contactWt = 1e-1
    totalLoss1 = segWt*segLoss + depWt*depthLoss + 0.0*colLoss + 1e2*thetaConstrs + icpWt*icpLossHand + icpWt*icpLossObj + 1e6*relPoseLoss + contactWt*contLoss
    if use2DJointLoss:
        totalLoss1 = totalLoss1 + j2dWt*joints2DLoss + wristJWt*wrist2DLoss
    totalLoss2 = 1.15 * segLoss + 5.0 * depthLoss + 0.0*colLoss + 1e2 * thetaConstrs
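    # totalLoss2 is kept for reference/experiments; only totalLoss1 is
    # optimized below.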

    # get the variables for opt
    optVarsHandList = scene.getVarsByItemID(handID, [varTypeEnum.HAND_TRANS, varTypeEnum.HAND_ROT,
                                                     # varTypeEnum.HAND_ROT_REL_DELTA, varTypeEnum.HAND_TRANS_REL_DELTA,
                                                     # varTypeEnum.HAND_JOINT,
                                                     ], [])
    optVarsHandDelta = scene.getVarsByItemID(handID, [varTypeEnum.HAND_TRANS_REL_DELTA, varTypeEnum.HAND_ROT_REL_DELTA], [])
    optVarsHandJoint = scene.getVarsByItemID(handID, [varTypeEnum.HAND_JOINT], [])
    optVarsHandBeta = scene.getVarsByItemID(handID, [varTypeEnum.HAND_BETA], [])
    optVarsObjList = scene.getVarsByItemID(objID, [varTypeEnum.OBJ_TRANS, varTypeEnum.OBJ_ROT], [])
    optVarsList = optVarsHandList + optVarsObjList + optVarsHandJoint
    optVarsListNoJoints = optVarsHandList + optVarsObjList

    # get the initial values of the variables for the optimizer
    initVals = []
    for fID in range(len(objParamInitList)):
        initVals.append(handParamInitList[fID].trans)
        initVals.append(handParamInitList[fID].theta[:3])
    initVals.append(handParamInitList[0].theta[handConstrs.validThetaIDs][3:])
    for fID in range(len(objParamInitList)):
        initVals.append(objParamInitList[fID].trans)
        initVals.append(objParamInitList[fID].rot)
    initValsNp = np.concatenate(initVals, axis=0)


    # setup optimizer
    opti1 = Optimizer(totalLoss1, optVarsList, 'Adam', learning_rate=0.02/2.0, initVals=initValsNp)
    opti2 = Optimizer(totalLoss1, optVarsListNoJoints, 'Adam', learning_rate=0.01)
    optiBeta = Optimizer(totalLoss1, optVarsHandBeta, 'Adam', learning_rate=0.05)


    # get the optimization reset ops
    resetOpt1 = tf.variables_initializer(opti1.optimizer.variables())
    resetOpt2 = tf.variables_initializer(opti2.optimizer.variables())

    # tf stuff
    config = tf.ConfigProto()
    config.gpu_options.per_process_gpu_memory_fraction = 0.9
    session = tf.Session(config=config)
    session.__enter__()
    tf.global_variables_initializer().run()

    # setup the plot window
    if FLAGS.showFig:
        plt.ion()
        fig = plt.figure()
        ax = fig.subplots(4, max(numFrames,2))
        axesList = [[],[],[],[]]
        for i in range(numFrames):
            axesList[0].append(ax[0, i].imshow(np.zeros((240,320,3), dtype=np.float32)))
            axesList[1].append(ax[1, i].imshow(np.zeros((240, 320, 3), dtype=np.float32)))
            axesList[2].append(ax[2, i].imshow(np.random.uniform(0,2,(240,320,3))))
            axesList[3].append(ax[3, i].imshow(np.random.uniform(0,1,(240,320,3))))
        plt.subplots_adjust(top=0.984,
                            bottom=0.016,
                            left=0.028,
                            right=0.99,
                            hspace=0.045,
                            wspace=0.124)
        figManager = plt.get_current_fig_manager()
        figManager.window.showMaximized()
    else:
        plt.ioff()

    # some init runs
    session.run(resetOpt1)
    session.run(resetOpt2)
    # opti1.runOptimization(session, 1, {loadData: True})
    tl = totalLoss1.eval(feed_dict={loadData: True})  # one eval with loadData=True pulls the first batch of observations

    # python renderer for rendering object texture
    configFile = join(HO3D_MULTI_CAMERA_DIR, FLAGS.seq, 'configs/configHandPose.json')
    with open(configFile) as config_file:
        data = yaml.safe_load(config_file)  # YAML is a superset of JSON, so this parses the JSON config
    modelPath = os.path.join(YCB_MODELS_DIR, data['obj'])
    pyRend = renderScene(h, w)
    pyRend.addObjectFromMeshFile(modelPath, 'obj')
    pyRend.addCamera()
    pyRend.creatcamProjMat(camProp.f, camProp.c, camProp.near, camProp.far)

    segLossList = []
    depLossList = []
    icpLossList = []
    relPoseLossList = []
    repulLossList = []
    joints2DLossList = []


    for i in range(numIter):
        print('iteration ',i)

        thumb13 = cv2.Rodrigues(thetaMat.eval(feed_dict={loadData: False})[0][7])[0]  # debug: one joint's rotation (unused)


        # staged schedule: warm-up without joint angles, then full pose, then shape (beta)
        if i < 35:
            opti2.runOptimization(session, 1, {loadData: False})
        elif i > 350:
            optiBeta.runOptimization(session, 1, {loadData: False})
        else:
            opti1.runOptimization(session, 1, {loadData: False})


        segLossList.append(segWt*segLoss.eval(feed_dict={loadData: False}))
        depLossList.append(depWt*depthLoss.eval(feed_dict={loadData: False}))
        icpLossList.append(icpWt*icpLossHand.eval(feed_dict={loadData: False}))
        relPoseLossList.append(1e6*relPoseLoss.eval(feed_dict={loadData: False}))
        repulLossList.append(contactWt*contLoss.eval(feed_dict={loadData: False}))
        if use2DJointLoss:
            joints2DLossList.append(j2dWt*joints2DLoss.eval(feed_dict={loadData: False}))
        # icpLossList.append(1e2*icpLossObj.eval(feed_dict={loadData: False}))

        # show all the images for analysis
        plt.title(str(i))
        depRen = virtObservs.depth.eval(feed_dict={loadData: False})
        depGT = realObservs.depth.eval(feed_dict={loadData: False})
        segRen = virtObservs.seg.eval(feed_dict={loadData: False})
        segGT = realObservs.seg.eval(feed_dict={loadData: False})
        poseMatNp = poseMat.eval(feed_dict={loadData: False})
        colRen = virtObservs.col.eval(feed_dict={loadData: False})
        colGT = realObservs.col.eval(feed_dict={loadData: False})
        frameIDList = (realObservs.frameID.eval(feed_dict={loadData: False}))
        frameIDList = [f.decode('UTF-8') for f in frameIDList]


        for f in range(numFrames):
            if FLAGS.showFig:
                frameID = frameIDList[f]
                # frameIDList.append(frameID)
                # render the obj col image
                pyRend.setObjectPose('obj', poseMatNp[f].T)
                if FLAGS.doPyRender:
                    cRend, dRend = pyRend.render()

                # blend with dirt rendered image to get full texture image
                dirtCol = colRen[f][:,:,[2,1,0]]
                objRendMask = (np.sum(np.abs(segRen[f] - objSegColor),2) < 0.05).astype(np.float32)
                objRendMask = np.stack([objRendMask,objRendMask,objRendMask], axis=2)
                if FLAGS.doPyRender:
                    finalCol = dirtCol*(1-objRendMask) + (cRend.astype(np.float32)/255.)*objRendMask

                axesList[0][f].set_data(colGT[f])
                if FLAGS.doPyRender:
                    axesList[1][f].set_data(finalCol)
                axesList[2][f].set_data(np.abs(depRen-depGT)[f,:,:,0])
                axesList[3][f].set_data(np.abs(segRen-segGT)[f,:,:,:])


                coordChangMat = np.array([[1., 0., 0.], [0., -1., 0.], [0., 0., -1.]])
                handJoints = scene.itemPropsDict[handID].transorfmedJs.eval(feed_dict={loadData: False})[f]
                camMat = camProp.getCamMat()
                handJointProj = cv2.projectPoints(handJoints.dot(coordChangMat), np.zeros((3,)), np.zeros((3,)), camMat, np.zeros((4,)))[0][:,0,:]
                imgIn = (colGT[f][:, :, [2, 1, 0]] * 255).astype(np.uint8).copy()
                imgIn = cv2.resize(imgIn, (dscale*imgIn.shape[1], dscale*imgIn.shape[0]), interpolation=cv2.INTER_CUBIC)
                imgJoints = showHandJoints(imgIn, np.round(handJointProj).astype(np.int32)[jointsMapManoToObman]*dscale,
                                           estIn=None, filename=None, upscale=1, lineThickness=2)

                objCorners = getObjectCorners(objMesh.v)
                rotObjNp = rotObj.eval(feed_dict={loadData: False})[f]
                transObjNp = transObj.eval(feed_dict={loadData: False})[f]
                objCornersTrans = np.matmul(objCorners, cv2.Rodrigues(rotObjNp)[0].T) + transObjNp
                objCornersProj = cv2.projectPoints(objCornersTrans.dot(coordChangMat), np.zeros((3,)), np.zeros((3,)), camMat, np.zeros((4,)))[0][:,0, :]
                imgJoints = showObjJoints(imgJoints, objCornersProj*dscale, lineThickness=2)


                alpha = 0.35
                rendMask = segRen[f]
                # rendMask[:,:,[1,2]] = 0
                rendMask = np.clip(255. * rendMask, 0, 255).astype('uint8')
                msk = rendMask.sum(axis=2) > 0
                msk = msk * alpha
                msk = np.stack([msk, msk, msk], axis=2)
                blended = msk * rendMask[:,:,[2,1,0]] + (1. - msk) * (colGT[f][:, :, [2, 1, 0]] * 255).astype(np.uint8)
                blended = blended.astype(np.uint8)
                cv2.imwrite(os.path.join(HO3D_MULTI_CAMERA_DIR, FLAGS.seq, 'dirt_grasp_pose', str(f) + '_annoVis.png'), imgJoints)
                cv2.imwrite(os.path.join(HO3D_MULTI_CAMERA_DIR, FLAGS.seq, 'dirt_grasp_pose', str(f) + '_blend.png'), blended)



        if FLAGS.showFig:
            plt.savefig(out_dir + '/'+str(i)+'.png')
            plt.waitforbuttonpress(0.01)

        # dump loss plots intermittently
        if (i%25 == 0 or i == (numIter-1)) and (i>0):
            segLossAll = np.array(segLossList)
            depLossAll = np.array(depLossList)
            icpLossAll = np.array(icpLossList)
            relPoseLossAll = np.array(relPoseLossList)
            repulLossAll = np.array(repulLossList)

            # dump each loss curve as a separate plot
            lossCurves = [('plotSeg', segLossAll, 'r'), ('plotDep', depLossAll, 'g'),
                          ('plotIcp', icpLossAll, 'b'), ('plotRelPose', relPoseLossAll, 'b'),
                          ('plotRepul', repulLossAll, 'b')]
            if use2DJointLoss:
                lossCurves.append(('plotJoints2D', np.array(joints2DLossList), 'b'))
            for name, curve, color in lossCurves:
                fig1 = plt.figure(2)
                plt.plot(np.arange(0, len(curve)), curve, color)
                fig1.savefig(out_dir + '/' + name + '_' + str(0) + '.png')
                plt.close(fig1)

        # save all the vars
        relPoseLossNp = relPoseLoss.eval(feed_dict={loadData: False})
        handJointNp = optVarsHandJoint[0].eval()
        optVarListNp = []
        for optVar in optVarsHandDelta:
            optVarListNp.append(optVar.eval())

        thetaMatNp = thetaMat.eval(feed_dict={loadData: False})
        thetaNp = np.reshape(cv2BatchRodrigues(np.reshape(thetaMatNp, [-1,3,3])), [numFrames, 48])
        betaNp = betaHand.eval(feed_dict={loadData: False})
        transNp = transHand.eval(feed_dict={loadData: False})
        rotObjNp = rotObj.eval(feed_dict={loadData: False})
        transObjNp = transObj.eval(feed_dict={loadData: False})
        JTransformed = scene.itemPropsDict[handID].transorfmedJs.eval(feed_dict={loadData: False})
        projPts = np.reshape(cv2ProjectPoints(camProp, np.reshape(JTransformed, [-1, 3])), [numFrames, JTransformed.shape[1], 2])
        # vis = getBatch2DPtVisFromDep(depRen, segRen, projPts, JTransformed, handSegColor)
        savePickleData(out_dir + '/' + 'graspPose'+'.pkl', {'beta':betaNp, 'fullpose': thetaNp, 'trans': transNp,
                                                        'rotObj':rotObjNp, 'transObj': transObjNp,
                                                        'JTransformed': JTransformed,
                                                        'frameID': frameIDList})#, 'JVis': vis})

    finalHandVert = scene.itemPropsDict[handID].transformedMesh.v.eval(feed_dict={loadData: False})
    handFace = scene.itemPropsDict[handID].transformedMesh.f
    finalObjVert = scene.itemPropsDict[objID].transformedMesh.v.eval(feed_dict={loadData: False})
    objFace = scene.itemPropsDict[objID].transformedMesh.f
    finalHandMesh = o3d.geometry.TriangleMesh()
    finalHandMesh.vertices = o3d.utility.Vector3dVector(finalHandVert[0][:,:3])
    finalHandMesh.triangles = o3d.utility.Vector3iVector(handFace)
    finalHandMesh.vertex_colors = o3d.utility.Vector3dVector(np.reshape(np.random.uniform(0., 1., finalHandVert.shape[1]*3), (finalHandVert.shape[1],3)))
    finalObjMesh = o3d.geometry.TriangleMesh()
    finalObjMesh.vertices = o3d.utility.Vector3dVector(finalObjVert[0][:,:3])
    finalObjMesh.triangles = o3d.utility.Vector3iVector(objFace)
    finalObjMesh.vertex_colors = o3d.utility.Vector3dVector(
        np.reshape(np.random.uniform(0., 1., finalObjVert.shape[1] * 3), (finalObjVert.shape[1], 3)))

    o3d.io.write_triangle_mesh(out_dir+'/'+'hand.ply', finalHandMesh)
    o3d.io.write_triangle_mesh(out_dir + '/' + 'object.ply', finalObjMesh)
    vis = o3d.visualization.Visualizer()
    vis.create_window(window_name='Open3D', width=640, height=480, left=0, top=0,
                      visible=True)  # use visible=True to visualize the point cloud
    vis.get_render_option().light_on = False
    vis.add_geometry(finalHandMesh)
    vis.add_geometry(finalObjMesh)
    vis.run()

    return
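
# -----------------------------------------------------------------------------
# The 2D joint loss in handPoseMF relies on the repo helpers cv2ProjectPoints
# and tfProjectPoints, which are defined elsewhere. The sketch below only
# illustrates the assumed pinhole projection (with the OpenGL -> OpenCV flip
# applied when isOpenGLCoords is set); camProp.f and camProp.c are assumed to
# hold (fx, fy) and (cx, cy). It is NOT the repo's actual implementation.
def _pinholeProjectSketch(camProp, pts3D, isOpenGLCoords=False):
    pts3D = np.asarray(pts3D, dtype=np.float64)
    if isOpenGLCoords:
        # OpenGL cameras look down -z with +y up; flip to the OpenCV convention
        pts3D = pts3D * np.array([1., -1., -1.])
    u = camProp.f[0] * pts3D[:, 0] / pts3D[:, 2] + camProp.c[0]
    v = camProp.f[1] * pts3D[:, 1] / pts3D[:, 2] + camProp.c[1]
    return np.stack([u, v], axis=1)  # (N, 2) pixel coordinates
# -----------------------------------------------------------------------------
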
def handObjectTrack(w, h, objParamInit, handParamInit, objMesh, camProp,
                    out_dir):
    ds = tf.data.Dataset.from_generator(
        lambda: dataGen(w, h),
        (tf.string, tf.float32, tf.float32, tf.float32, tf.float32),
        ((None, ), (None, h, w, 3), (None, h, w, 3), (None, h, w, 3),
         (None, h, w, 3)))
    # assert len(objParamInitList)==len(handParamInitList)

    numFrames = 1

    # read real observations
    frameCntInt, loadData, realObservs = LossObservs.getRealObservables(
        ds, numFrames, w, h)
    icp = Icp(realObservs, camProp)

    # set up the scene
    scene = Scene(optModeEnum.MULTIFRAME_RIGID_HO_POSE, frameCnt=numFrames)
    objID = scene.addObject(objMesh, objParamInit, segColor=objSegColor)
    handID = scene.addHand(handParamInit, handSegColor, baseItemID=objID)
    scene.addCamera(f=camProp.f,
                    c=camProp.c,
                    near=camProp.near,
                    far=camProp.far,
                    frameSize=camProp.frameSize)
    finalMesh = scene.getFinalMesh()

    # render the scene
    renderer = DirtRenderer(finalMesh, renderModeEnum.SEG_COLOR_DEPTH)
    virtObservs = renderer.render()

    # get loss over observables
    observLoss = LossObservs(virtObservs, realObservs,
                             renderModeEnum.SEG_COLOR_DEPTH)
    segLoss, depthLoss, colLoss = observLoss.getL2Loss(isClipDepthLoss=True,
                                                       pyrLevel=2)

    # get parameters and constraints
    handConstrs = Constraints()
    paramListHand = scene.getVarsByItemID(
        handID, [varTypeEnum.HAND_JOINT, varTypeEnum.HAND_ROT])
    jointAngs = paramListHand[0]
    handRot = paramListHand[1]
    validTheta = tf.concat([handRot, jointAngs], axis=0)
    theta = handConstrs.getFullThetafromValidTheta(validTheta)
    thetaConstrs, _ = handConstrs.getHandThetaConstraints(validTheta,
                                                          isValidTheta=True)

    paramListObj = scene.getParamsByItemID(
        [parTypeEnum.OBJ_ROT, parTypeEnum.OBJ_TRANS, parTypeEnum.OBJ_POSE_MAT],
        objID)
    rotObj = paramListObj[0]
    transObj = paramListObj[1]
    poseMat = paramListObj[2]

    paramListHand = scene.getParamsByItemID([
        parTypeEnum.HAND_THETA, parTypeEnum.HAND_TRANS, parTypeEnum.HAND_BETA
    ], handID)
    thetaMat = paramListHand[0]
    transHand = paramListHand[1]
    betaHand = paramListHand[2]

    # get icp losses
    icpLossHand = icp.getLoss(scene.itemPropsDict[handID].transformedMesh.v,
                              handSegColor)
    icpLossObj = icp.getLoss(scene.itemPropsDict[objID].transformedMesh.v,
                             objSegColor)

    # get rel hand obj pose loss
    handTransVars = tf.stack(scene.getVarsByItemID(
        handID, [varTypeEnum.HAND_TRANS_REL_DELTA]),
                             axis=0)
    handRotVars = tf.stack(scene.getVarsByItemID(
        handID, [varTypeEnum.HAND_ROT_REL_DELTA]),
                           axis=0)
    relPoseLoss = handConstrs.getHandObjRelDeltaPoseConstraint(
        handRotVars, handTransVars)

    # get final loss
    icpLoss = 1e3 * icpLossHand + 1e3 * icpLossObj
    totalLoss1 = 1.0e1 * segLoss + 1e0 * depthLoss + 0.0 * colLoss + 1e2 * thetaConstrs + icpLoss + 1e6 * relPoseLoss
    totalLoss2 = 1.15 * segLoss + 5.0 * depthLoss + 0.0 * colLoss + 1e2 * thetaConstrs

    # get the variables for opt
    optVarsHandList = scene.getVarsByItemID(
        handID,
        [  #varTypeEnum.HAND_TRANS, varTypeEnum.HAND_ROT,
            # varTypeEnum.HAND_ROT_REL_DELTA, varTypeEnum.HAND_TRANS_REL_DELTA,
            varTypeEnum.HAND_JOINT
        ],
        [])
    optVarsHandDelta = scene.getVarsByItemID(
        handID,
        [varTypeEnum.HAND_TRANS_REL_DELTA, varTypeEnum.HAND_ROT_REL_DELTA], [])
    optVarsHandJoint = scene.getVarsByItemID(handID, [varTypeEnum.HAND_JOINT],
                                             [])
    optVarsObjList = scene.getVarsByItemID(
        objID, [varTypeEnum.OBJ_TRANS, varTypeEnum.OBJ_ROT], [])
    optVarsList = optVarsObjList  #+ optVarsHandList
    optVarsListNoJoints = optVarsObjList  #+ optVarsHandList

    # get the initial values of the variables for the optimizer
    # (this function tracks a single frame, so the single-frame inits are used)
    initVals = []
    initVals.append(handParamInit.trans)
    initVals.append(handParamInit.theta[:3])
    initVals.append(handParamInit.theta[handConstrs.validThetaIDs][3:])
    initVals.append(objParamInit.trans)
    initVals.append(objParamInit.rot)
    initValsNp = np.concatenate(initVals, axis=0)

    # setup optimizer
    opti1 = Optimizer(totalLoss1,
                      optVarsList,
                      'Adam',
                      learning_rate=0.02 / 2.0,
                      initVals=initValsNp)
    opti2 = Optimizer(totalLoss1,
                      optVarsListNoJoints,
                      'Adam',
                      learning_rate=0.01)

    # get the optimization reset ops
    resetOpt1 = tf.variables_initializer(opti1.optimizer.variables())
    resetOpt2 = tf.variables_initializer(opti2.optimizer.variables())

    # tf stuff
    config = tf.ConfigProto()
    config.gpu_options.per_process_gpu_memory_fraction = 0.4
    session = tf.Session(config=config)
    session.__enter__()
    tf.global_variables_initializer().run()

    # setup the plot window
    if showFig:
        plt.ion()
        fig = plt.figure()
        ax = fig.subplots(4, max(numFrames, 1))
        axesList = [[], [], [], []]
        for i in range(numFrames):
            axesList[0].append(ax[0].imshow(
                np.zeros((240, 320, 3), dtype=np.float32)))
            axesList[1].append(ax[1].imshow(
                np.zeros((240, 320, 3), dtype=np.float32)))
            axesList[2].append(ax[2].imshow(
                np.random.uniform(0, 2, (240, 320, 3))))
            axesList[3].append(ax[3].imshow(
                np.random.uniform(0, 1, (240, 320, 3))))
        plt.subplots_adjust(top=0.984,
                            bottom=0.016,
                            left=0.028,
                            right=0.99,
                            hspace=0.045,
                            wspace=0.124)
        figManager = plt.get_current_fig_manager()
        figManager.window.showMaximized()

    # python renderer for rendering object texture
    pyRend = renderScene(h, w)
    pyRend.addObjectFromMeshFile(modelPath, 'obj')
    pyRend.addCamera()
    pyRend.creatcamProjMat(camProp.f, camProp.c, camProp.near, camProp.far)

    segLossList = []
    depLossList = []
    icpLossList = []
    relPoseLossList = []

    while True:
        session.run(resetOpt1)
        session.run(resetOpt2)

        # load new frame
        opti1.runOptimization(session, 1, {loadData: True})
        # print(icpLoss.eval(feed_dict={loadData: False}))
        # print(segLoss.eval(feed_dict={loadData: False}))
        # print(depthLoss.eval(feed_dict={loadData: False}))

        # run the optimization for new frame
        frameID = (realObservs.frameID.eval(
            feed_dict={loadData: False}))[0].decode('UTF-8')
        opti1.runOptimization(session, FLAGS.numIter, {loadData: False})

        segLossList.append(1.0 * segLoss.eval(feed_dict={loadData: False}))
        depLossList.append(1.0 * depthLoss.eval(feed_dict={loadData: False}))
        icpLossList.append(icpLoss.eval(feed_dict={loadData: False}))
        relPoseLossList.append(1e3 *
                               relPoseLoss.eval(feed_dict={loadData: False}))
        # icpLossList.append(1e2*icpLossObj.eval(feed_dict={loadData: False}))

        # show all the images for analysis
        plt.title(frameID)
        depRen = virtObservs.depth.eval(feed_dict={loadData: False})
        depGT = realObservs.depth.eval(feed_dict={loadData: False})
        segRen = virtObservs.seg.eval(feed_dict={loadData: False})
        segGT = realObservs.seg.eval(feed_dict={loadData: False})
        poseMatNp = poseMat.eval(feed_dict={loadData: False})
        colRen = virtObservs.col.eval(feed_dict={loadData: False})
        colGT = realObservs.col.eval(feed_dict={loadData: False})
        for f in range(numFrames):
            if doPyRendFinalImage:
                # render the obj col image
                pyRend.setObjectPose('obj', poseMatNp[f].T)
                cRend, dRend = pyRend.render()

                # blend with dirt rendered image to get full texture image
                dirtCol = colRen[f][:, :, [2, 1, 0]]
                objRendMask = (np.sum(np.abs(segRen[f] - objSegColor), 2) <
                               0.05).astype(np.float32)
                objRendMask = np.stack([objRendMask, objRendMask, objRendMask],
                                       axis=2)
                finalCol = dirtCol * (1 - objRendMask) + (
                    cRend.astype(np.float32) / 255.) * objRendMask

            if showFig:
                axesList[0][f].set_data(colGT[f])
                if doPyRendFinalImage:
                    axesList[1][f].set_data(finalCol)
                axesList[2][f].set_data(np.abs(depRen - depGT)[f, :, :, 0])
                axesList[3][f].set_data(np.abs(segRen - segGT)[f, :, :, :])

            if f >= 0:  # always true; kept as an easy toggle for skipping the annotation dump
                coordChangMat = np.array([[1., 0., 0.], [0., -1., 0.],
                                          [0., 0., -1.]])
                handJoints = scene.itemPropsDict[handID].transorfmedJs.eval(
                    feed_dict={loadData: False})[f]
                camMat = camProp.getCamMat()
                handJointProj = cv2.projectPoints(
                    handJoints.dot(coordChangMat), np.zeros((3, )),
                    np.zeros((3, )), camMat, np.zeros((4, )))[0][:, 0, :]
                imgIn = (colGT[f][:, :, [2, 1, 0]] * 255).astype(
                    np.uint8).copy()
                imgIn = cv2.resize(
                    imgIn, (imgIn.shape[1] * dscale, imgIn.shape[0] * dscale),
                    interpolation=cv2.INTER_LANCZOS4)
                imgJoints = showHandJoints(
                    imgIn,
                    np.round(handJointProj).astype(
                        np.int32)[jointsMapManoToObman] * dscale,
                    estIn=None,
                    filename=None,
                    upscale=1,
                    lineThickness=2)

                objCorners = getObjectCorners(objMesh.v)
                rotObjNp = rotObj.eval(feed_dict={loadData: False})[f]
                transObjNp = transObj.eval(feed_dict={loadData: False})[f]
                objCornersTrans = np.matmul(
                    objCorners,
                    cv2.Rodrigues(rotObjNp)[0].T) + transObjNp
                objCornersProj = cv2.projectPoints(
                    objCornersTrans.dot(coordChangMat), np.zeros((3, )),
                    np.zeros((3, )), camMat, np.zeros((4, )))[0][:, 0, :]
                imgJoints = showObjJoints(imgJoints,
                                          objCornersProj * dscale,
                                          lineThickness=2)

                alpha = 0.35
                rendMask = segRen[f]
                # rendMask[:,:,[1,2]] = 0
                rendMask = np.clip(255. * rendMask, 0, 255).astype('uint8')
                msk = rendMask.sum(axis=2) > 0
                msk = msk * alpha
                msk = np.stack([msk, msk, msk], axis=2)
                blended = msk * rendMask[:, :, [2, 1, 0]] + (1. - msk) * (
                    colGT[f][:, :, [2, 1, 0]] * 255).astype(np.uint8)
                blended = blended.astype(np.uint8)

                cv2.imwrite(out_dir + '/annoVis_' + frameID + '.jpg',
                            imgJoints)
                cv2.imwrite(out_dir + '/annoBlend_' + frameID + '.jpg',
                            blended)
                cv2.imwrite(out_dir + '/maskOnly_' + frameID + '.jpg',
                            (segRen[0] * 255).astype(np.uint8))
                depthEnc = encodeDepthImg(depRen[0, :, :, 0])
                cv2.imwrite(out_dir + '/renderDepth_' + frameID + '.jpg',
                            depthEnc)
                if doPyRendFinalImage:
                    cv2.imwrite(out_dir + '/renderCol_' + frameID + '.jpg',
                                (finalCol[:, :, [2, 1, 0]] * 255).astype(
                                    np.uint8))

        if showFig:
            plt.savefig(out_dir + '/' + frameID + '.png')
            plt.waitforbuttonpress(0.01)

        # save all the vars
        optVarListNp = []
        for optVar in optVarsHandDelta:
            optVarListNp.append(optVar.eval())

        thetaNp = thetaMat.eval(feed_dict={loadData: False})[0]
        betaNp = betaHand.eval(feed_dict={loadData: False})[0]
        transNp = transHand.eval(feed_dict={loadData: False})[0]
        rotObjNp = rotObj.eval(feed_dict={loadData: False})[0]
        transObjNp = transObj.eval(feed_dict={loadData: False})[0]
        JTransformed = scene.itemPropsDict[handID].transorfmedJs.eval(
            feed_dict={loadData: False})
        handJproj = np.reshape(
            cv2ProjectPoints(camProp, np.reshape(JTransformed, [-1, 3])),
            [numFrames, JTransformed.shape[1], 2])
        # vis = getBatch2DPtVisFromDep(depRen, segRen, projPts, JTransformed, handSegColor)
        objCornersRest = np.load(
            os.path.join(YCB_OBJECT_CORNERS_DIR,
                         obj.split('/')[0], 'corners.npy'))
        objCornersTransormed = objCornersRest.dot(
            cv2.Rodrigues(rotObjNp)[0].T) + transObjNp
        objCornersproj = np.reshape(
            cv2ProjectPoints(camProp, np.reshape(objCornersTransormed,
                                                 [-1, 3])),
            [objCornersTransormed.shape[0], 2])

        savePickleData(
            out_dir + '/' + frameID + '.pkl', {
                'beta': betaNp,
                'fullpose': thetaNp,
                'trans': transNp,
                'rotObj': rotObjNp,
                'transObj': transObjNp,
                'JTransformed': JTransformed,
                'objCornersRest': objCornersRest,
                'objCornersTransormed': objCornersTransormed,
                'objName': obj.split('/')[0],
                'objLabel': objLabel
            })
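
# -----------------------------------------------------------------------------
# The multi-frame handPoseMF below adds temporal smoothness terms via
# Constraints.getTemporalConstraint(vars, type='ZERO_VEL'|'ZERO_ACCL'), defined
# elsewhere in the repo. A minimal sketch of the assumed penalties, for
# variables stacked per frame along axis 0 (an assumption, not the repo code):
def _temporalConstraintSketch(varsPerFrame, type='ZERO_VEL'):
    vel = varsPerFrame[1:] - varsPerFrame[:-1]  # frame-to-frame change
    if type == 'ZERO_VEL':
        return tf.reduce_sum(tf.square(vel))    # penalize any motion
    accl = vel[1:] - vel[:-1]                   # second difference
    return tf.reduce_sum(tf.square(accl))       # penalize only acceleration
# -----------------------------------------------------------------------------
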
def handPoseMF(w, h, objParamInitList, handParamInitList, objMesh, camProp, out_dir):
    ds = tf.data.Dataset.from_generator(lambda: dataGen(w, h, batchSize),
                                        (tf.string, tf.float32, tf.float32, tf.float32, tf.float32),
                                        ((None,), (None, h, w, 3), (None, h, w, 3), (None, h, w, 3), (None, h, w, 3)))

    dsVarInit = tf.data.Dataset.from_generator(lambda: initVarGen(batchSize),
                                        (tf.float32, tf.float32, tf.float32, tf.float32, tf.float32),
                                        ((batchSize, 48), (batchSize, 3), (batchSize, 10), (batchSize, 3), (batchSize, 3)))

    assert len(objParamInitList)==len(handParamInitList)

    numFrames = len(objParamInitList)

    # read real observations
    frameCntInt, loadData, realObservs = LossObservs.getRealObservables(ds, numFrames, w, h)
    icp = Icp(realObservs, camProp)

    # set up the scene
    scene = Scene(optModeEnum.MULTIFRAME_RIGID_HO_POSE_JOINT, frameCnt=numFrames)
    objID = scene.addObject(objMesh, objParamInitList, segColor=objSegColor)
    handID = scene.addHand(handParamInitList, handSegColor, baseItemID=objID)
    scene.addCamera(f=camProp.f, c=camProp.c, near=camProp.near, far=camProp.far, frameSize=camProp.frameSize)
    finalMesh = scene.getFinalMesh()

    # render the scene
    renderer = DirtRenderer(finalMesh, renderModeEnum.SEG_COLOR_DEPTH)
    virtObservs = renderer.render()

    # get loss over observables
    observLoss = LossObservs(virtObservs, realObservs, renderModeEnum.SEG_COLOR_DEPTH)
    segLoss, depthLoss, colLoss = observLoss.getL2Loss(isClipDepthLoss=True, pyrLevel=2)

    # get parameters and constraints
    handConstrs = Constraints()
    paramListHand = scene.getVarsByItemID(handID, [varTypeEnum.HAND_JOINT, varTypeEnum.HAND_ROT])
    jointAngs = paramListHand[0]
    handRot = paramListHand[1]
    validTheta = tf.concat([handRot, jointAngs], axis=0)
    theta = handConstrs.getFullThetafromValidTheta(validTheta)
    thetaConstrs, _ = handConstrs.getHandThetaConstraints(validTheta, isValidTheta=True)

    # some variables for vis and analysis
    paramListObj = scene.getParamsByItemID([parTypeEnum.OBJ_ROT, parTypeEnum.OBJ_TRANS, parTypeEnum.OBJ_POSE_MAT], objID)
    rotObj = paramListObj[0]
    transObj = paramListObj[1]
    poseMat = paramListObj[2]

    paramListHand = scene.getParamsByItemID([parTypeEnum.HAND_THETA, parTypeEnum.HAND_TRANS, parTypeEnum.HAND_BETA],
                                           handID)
    thetaMat = paramListHand[0]
    transHand = paramListHand[1]
    betaHand = paramListHand[2]


    # contact loss
    hoContact = ContactLoss(scene.itemPropsDict[objID].transformedMesh, scene.itemPropsDict[handID].transformedMesh,
                            scene.itemPropsDict[handID].transorfmedJs)
    contLoss = hoContact.getRepulsionLoss()

    # get icp losses
    icpLossHand = icp.getLoss(scene.itemPropsDict[handID].transformedMesh.v, handSegColor)
    icpLossObj = icp.getLoss(scene.itemPropsDict[objID].transformedMesh.v, objSegColor)

    # get rel hand obj pose loss
    handTransVars = tf.stack(scene.getVarsByItemID(handID, [varTypeEnum.HAND_TRANS_REL_DELTA]), axis=0)
    handRotVars = tf.stack(scene.getVarsByItemID(handID, [varTypeEnum.HAND_ROT_REL_DELTA]), axis=0)
    relPoseLoss = handConstrs.getHandObjRelDeltaPoseConstraint(handRotVars, handTransVars)

    # get temporal loss
    handJointVars = tf.stack(scene.getVarsByItemID(handID, [varTypeEnum.HAND_JOINT]), axis=0)
    handRotVars = tf.stack(scene.getVarsByItemID(handID, [varTypeEnum.HAND_ROT]), axis=0)
    handTransVars = tf.stack(scene.getVarsByItemID(handID, [varTypeEnum.HAND_TRANS]), axis=0)
    objRotVars = tf.stack(scene.getVarsByItemID(objID, [varTypeEnum.OBJ_ROT]), axis=0)
    objTransVars = tf.stack(scene.getVarsByItemID(objID, [varTypeEnum.OBJ_TRANS]), axis=0)

    handJointsTempLoss = handConstrs.getTemporalConstraint(handJointVars, type='ZERO_ACCL')
    objRotTempLoss = handConstrs.getTemporalConstraint(objRotVars, type='ZERO_ACCL')
    handRotTempLoss = handConstrs.getTemporalConstraint(handRotVars, type='ZERO_ACCL')
    objTransTempLoss = handConstrs.getTemporalConstraint(objTransVars, type='ZERO_VEL')
    handTransTempLoss = handConstrs.getTemporalConstraint(handTransVars, type='ZERO_VEL')

    # get final loss
    segWt = 10.0
    depWt = 5.0
    colWt = 0.0
    thetaConstWt = 1e2
    icpHandWt = 1e2
    icpObjWt = 1e2
    relPoseWt = 1e2
    contactWt = 0.  # contact loss disabled

    handJointsTempWt = 1e1
    objRotTempWt = 1e1
    handRotTempWt = 0.
    objTransTempWt = 5e2
    handTransTempWt = 5e1
    totalLoss1 = segWt*segLoss + depWt*depthLoss + colWt*colLoss + thetaConstWt*thetaConstrs + icpHandWt*icpLossHand + icpObjWt*icpLossObj + \
                 relPoseWt*relPoseLoss + contactWt*contLoss + handJointsTempWt*handJointsTempLoss + \
                 objRotTempWt*objRotTempLoss + objTransTempWt*objTransTempLoss# + handRotTempWt*handRotTempLoss + handTransTempWt*handTransTempLoss
    totalLoss2 = 1.15 * segLoss + 5.0 * depthLoss + 0.0*colLoss + 1e2 * thetaConstrs

    # get the variables for opt
    optVarsHandList = scene.getVarsByItemID(handID, [varTypeEnum.HAND_TRANS, varTypeEnum.HAND_ROT,
                                                     #varTypeEnum.HAND_ROT_REL_DELTA, varTypeEnum.HAND_TRANS_REL_DELTA,
                                                     varTypeEnum.HAND_JOINT], [])
    optVarsHandDelta = scene.getVarsByItemID(handID, [varTypeEnum.HAND_TRANS_REL_DELTA, varTypeEnum.HAND_ROT_REL_DELTA], [])
    optVarsHandJoint = scene.getVarsByItemID(handID, [varTypeEnum.HAND_JOINT], [])
    optVarsObjList = scene.getVarsByItemID(objID, [varTypeEnum.OBJ_TRANS, varTypeEnum.OBJ_ROT], [])
    optVarsList = optVarsHandList + optVarsObjList
    optVarsListNoJoints = optVarsHandList + optVarsObjList

    # get var init op
    initOpList = getVarInitsOpJoints(dsVarInit, scene, handID, objID, numFrames)

    # get the initial values of the variables for the optimizer
    initVals = []
    for fID in range(len(objParamInitList)):
        initVals.append(handParamInitList[fID].trans)
        initVals.append(handParamInitList[fID].theta[:3])
    initVals.append(handParamInitList[0].theta[handConstrs.validThetaIDs][3:])
    for fID in range(len(objParamInitList)):
        initVals.append(objParamInitList[fID].trans)
        initVals.append(objParamInitList[fID].rot)
    initValsNp = np.concatenate(initVals, axis=0)


    # setup optimizer
    opti1 = Optimizer(totalLoss1, optVarsList, 'Adam', learning_rate=0.01, initVals=initValsNp)
    opti2 = Optimizer(totalLoss1, optVarsListNoJoints, 'Adam', learning_rate=0.01)

    # get the optimization reset ops
    resetOpt1 = tf.variables_initializer(opti1.optimizer.variables())
    resetOpt2 = tf.variables_initializer(opti2.optimizer.variables())

    # tf stuff
    config = tf.ConfigProto()
    config.gpu_options.per_process_gpu_memory_fraction = 0.9
    session = tf.Session(config=config)
    session.__enter__()
    tf.global_variables_initializer().run()

    # setup the plot window

    if showFigForIter:
        plt.ion()
        fig = plt.figure()
        ax = fig.subplots(4, max(numFrames,2))
        axesList = [[],[],[],[]]
        for i in range(numFrames):
            axesList[0].append(ax[0, i].imshow(np.zeros((240,320,3), dtype=np.float32)))
            axesList[1].append(ax[1, i].imshow(np.zeros((240, 320, 3), dtype=np.float32)))
            axesList[2].append(ax[2, i].imshow(np.random.uniform(0,2,(240,320,3))))
            axesList[3].append(ax[3, i].imshow(np.random.uniform(0,1,(240,320,3))))
        plt.subplots_adjust(top=0.984,
                            bottom=0.016,
                            left=0.028,
                            right=0.99,
                            hspace=0.045,
                            wspace=0.124)
        figManager = plt.get_current_fig_manager()
        figManager.window.showMaximized()
    else:
        plt.ioff()

    # some init runs
    session.run(resetOpt1)
    session.run(resetOpt2)
    # opti1.runOptimization(session, 1, {loadData: True})
    # tl = totalLoss1.eval(feed_dict={loadData: True})

    # python renderer for rendering object texture
    pyRend = renderScene(h, w)
    pyRend.addObjectFromMeshFile(modelPath, 'obj')
    pyRend.addCamera()
    pyRend.creatcamProjMat(camProp.f, camProp.c, camProp.near, camProp.far)

    while True:
        segLossList = []
        depLossList = []
        icpLossList = []
        relPoseLossList = []
        repulLossList = []
        jointTempLossList = []
        objRotTempLossList = []
        objTransTempLossList = []

        # load the init values for variables
        session.run(initOpList)

        # load the real observations
        tl = totalLoss1.eval(feed_dict={loadData: True})
        for i in range(numIter):
            print('iteration ',i)

            thumb13 = cv2.Rodrigues(thetaMat.eval(feed_dict={loadData: False})[0][7])[0]  # debug: one joint's rotation (unused)

            # run the optimization for new frame
            frameID = (realObservs.frameID.eval(feed_dict={loadData: False}))[0].decode('UTF-8')
            iterDir = join(out_dir, frameID)
            if not os.path.exists(iterDir):
                os.mkdir(iterDir)
            if i < 0:  # disabled: raise this threshold to warm-start without joint angles
                opti2.runOptimization(session, 1, {loadData: False})
            else:
                opti1.runOptimization(session, 1, {loadData: False})


            segLossList.append(segWt*segLoss.eval(feed_dict={loadData: False}))
            depLossList.append(depWt*depthLoss.eval(feed_dict={loadData: False}))
            icpLossList.append(icpHandWt*icpLossHand.eval(feed_dict={loadData: False}))
            relPoseLossList.append(relPoseWt*relPoseLoss.eval(feed_dict={loadData: False}))
            repulLossList.append(contactWt*contLoss.eval(feed_dict={loadData: False}))
            jointTempLossList.append(handJointsTempWt*handJointsTempLoss.eval(feed_dict={loadData: False}))
            objRotTempLossList.append(objRotTempWt * objRotTempLoss.eval(feed_dict={loadData: False}))
            objTransTempLossList.append(objTransTempWt * objTransTempLoss.eval(feed_dict={loadData: False}))
            # icpLossList.append(1e2*icpLossObj.eval(feed_dict={loadData: False}))

            # show all the images for analysis

            plt.title(str(i))
            depRen = virtObservs.depth.eval(feed_dict={loadData: False})
            depGT = realObservs.depth.eval(feed_dict={loadData: False})
            segRen = virtObservs.seg.eval(feed_dict={loadData: False})
            segGT = realObservs.seg.eval(feed_dict={loadData: False})
            poseMatNp = poseMat.eval(feed_dict={loadData: False})
            colRen = virtObservs.col.eval(feed_dict={loadData: False})
            colGT = realObservs.col.eval(feed_dict={loadData: False})
            finalCol = np.zeros_like(colRen)
            for f in range(numFrames):


                if doPyRender:
                    # render the obj col image
                    pyRend.setObjectPose('obj', poseMatNp[f].T)
                    cRend, dRend = pyRend.render()
                    # blend with dirt rendered image to get full texture image
                    dirtCol = colRen[f][:,:,[2,1,0]]
                    objRendMask = (np.sum(np.abs(segRen[f] - objSegColor),2) < 0.05).astype(np.float32)
                    objRendMask = np.stack([objRendMask,objRendMask,objRendMask], axis=2)
                    finalCol[f] = dirtCol*(1-objRendMask) + (cRend.astype(np.float32)/255.)*objRendMask

                if showFigForIter:
                    axesList[0][f].set_data(colGT[f])
                    if doPyRender:
                        axesList[1][f].set_data(finalCol[f])
                    axesList[2][f].set_data(np.abs(depRen-depGT)[f,:,:,0])
                    axesList[3][f].set_data(np.abs(segRen-segGT)[f,:,:,:])


            if showFigForIter:
                plt.savefig(iterDir + '/'+frameID+'_'+str(i)+'.png')
                plt.waitforbuttonpress(0.01)

        frameID = (realObservs.frameID.eval(feed_dict={loadData: False}))  # [0].decode('UTF-8')
        frameID = [f.decode('UTF-8') for f in frameID]
        print(frameID)
        for f in range(numFrames):
            coordChangMat = np.array([[1., 0., 0.], [0., -1., 0.], [0., 0., -1.]])
            handJoints = scene.itemPropsDict[handID].transorfmedJs.eval(feed_dict={loadData: False})[f]
            camMat = camProp.getCamMat()
            handJointProj = cv2.projectPoints(handJoints.dot(coordChangMat), np.zeros((3,)), np.zeros((3,)),
                                              camMat, np.zeros((4,)))[0][:, 0, :]
            imgIn = cv2.imread(join(base_dir, 'rgb', camID, frameID[f] + '.png'))  # full-res RGB for the annotation overlay
            imgJoints = showHandJoints(imgIn,
                                       np.round(handJointProj).astype(np.int32)[jointsMapManoToObman] * dscale,
                                       estIn=None, filename=None, upscale=1, lineThickness=2)

            objCorners = getObjectCorners(objMesh.v)
            rotObjNp = rotObj.eval(feed_dict={loadData: False})[f]
            transObjNp = transObj.eval(feed_dict={loadData: False})[f]
            objCornersTrans = np.matmul(objCorners, cv2.Rodrigues(rotObjNp)[0].T) + transObjNp
            objCornersProj = cv2.projectPoints(objCornersTrans.dot(coordChangMat), np.zeros((3,)), np.zeros((3,)),
                                               camMat, np.zeros((4,)))[0][:, 0, :]
            imgJoints = showObjJoints(imgJoints, objCornersProj * dscale, lineThickness=2)


            alpha = 0.35
            rendMask = segRen[f]
            # rendMask[:,:,[1,2]] = 0
            rendMask = np.clip(255. * rendMask, 0, 255).astype('uint8')
            msk = rendMask.sum(axis=2) > 0
            msk = msk * alpha
            msk = np.stack([msk, msk, msk], axis=2)
            blended = msk * rendMask[:, :, [2, 1, 0]] + (1. - msk) * (colGT[f][:, :, [2, 1, 0]] * 255).astype(np.uint8)
            blended = blended.astype(np.uint8)

            # cv2.imwrite(base_dir+'/' + str(f) + '_blend.png', imgJoints)
            cv2.imwrite(out_dir + '/annoVis_' + frameID[f] + '.jpg', imgJoints)
            cv2.imwrite(out_dir + '/annoBlend_' + frameID[f] + '.jpg', blended)
            cv2.imwrite(out_dir + '/maskOnly_' + frameID[f] + '.jpg', (segRen[f] * 255).astype(np.uint8))
            depthEnc = encodeDepthImg(depRen[f, :, :, 0])
            cv2.imwrite(out_dir + '/renderDepth_' + frameID[f] + '.jpg', depthEnc)
            if doPyRender:
                cv2.imwrite(out_dir + '/renderCol_' + frameID[f] + '.jpg',
                            (finalCol[f][:, :, [2, 1, 0]] * 255).astype(np.uint8))

        # dump loss plots
        lossCurves = [('plotSeg', np.array(segLossList), 'r'),
                      ('plotDep', np.array(depLossList), 'g'),
                      ('plotIcp', np.array(icpLossList), 'b'),
                      ('plotRelPose', np.array(relPoseLossList), 'b'),
                      ('plotRepul', np.array(repulLossList), 'b'),
                      ('plotJointTemp', np.array(jointTempLossList), 'b'),
                      ('plotObjRotTemp', np.array(objRotTempLossList), 'b'),
                      ('plotObjTransTemp', np.array(objTransTempLossList), 'b')]
        for name, curve, color in lossCurves:
            fig1 = plt.figure(2)
            plt.plot(np.arange(0, len(curve)), curve, color)
            fig1.savefig(iterDir + '/' + name + '_%s' % (frameID[0]) + '.png')
            plt.close(fig1)

        # save all the vars
        relPoseLossNp = relPoseLoss.eval(feed_dict={loadData: False})
        handJointNp = optVarsHandJoint[0].eval()
        optVarListNp = []
        for optVar in optVarsHandDelta:
            optVarListNp.append(optVar.eval())


        thetaMatNp = thetaMat.eval(feed_dict={loadData: False})
        thetaNp = np.reshape(cv2BatchRodrigues(np.reshape(thetaMatNp, [-1,3,3])), [numFrames, 48])
        betaNp = betaHand.eval(feed_dict={loadData: False})
        transNp = transHand.eval(feed_dict={loadData: False})
        rotObjNp = rotObj.eval(feed_dict={loadData: False})
        transObjNp = transObj.eval(feed_dict={loadData: False})
        JTransformed = scene.itemPropsDict[handID].transorfmedJs.eval(feed_dict={loadData: False})
        projPts = np.reshape(cv2ProjectPoints(camProp, np.reshape(JTransformed, [-1, 3])), [numFrames, JTransformed.shape[1], 2])
        # vis = getBatch2DPtVisFromDep(depRen, segRen, projPts, JTransformed, handSegColor)
        for f in range(numFrames):
            objCornersRest = np.load(os.path.join(YCB_OBJECT_CORNERS_DIR, obj.split('/')[0], 'corners.npy'))
            objCornersTransormed = objCornersRest.dot(cv2.Rodrigues(rotObjNp[f])[0].T) + transObjNp[f]
            savePickleData(out_dir + '/' + frameID[f] + '.pkl', {'beta': betaNp[f], 'fullpose': thetaNp[f], 'trans': transNp[f],
                                                              'rotObj': rotObjNp[f], 'transObj': transObjNp[f],
                                                              'JTransformed': JTransformed[f],
                                                              'objCornersRest': objCornersRest,
                                                              'objCornersTransormed': objCornersTransormed,
                                                              'objName': obj.split('/')[0], 'objLabel': objLabel})
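
# -----------------------------------------------------------------------------
# cv2BatchRodrigues (used above to turn the per-joint 3x3 rotation matrices
# into the 48-dim axis-angle MANO 'fullpose') is defined elsewhere in the repo;
# a sketch of the assumed behaviour using OpenCV's Rodrigues on each matrix:
def _batchRodriguesSketch(rotMats):
    rotMats = np.reshape(rotMats, [-1, 3, 3])
    return np.stack([cv2.Rodrigues(R)[0][:, 0] for R in rotMats], axis=0)  # (N, 3)
# -----------------------------------------------------------------------------
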
def objectTracker(w, h, paramInit, camProp, objMesh, out_dir, configData):
    '''
    Generative object tracking
    :param w: width of the image
    :param h: height of the image
    :param paramInit: object of objParams class
    :param camProp: camera properties object
    :param objMesh: object mesh
    :param out_dir: out directory
    :param configData: parsed config with the object model name
    :return:
    '''
    ds = tf.data.Dataset.from_generator(
        lambda: dataGen(w, h, datasetName),
        (tf.string, tf.float32, tf.float32, tf.float32, tf.float32),
        ((None, ), (None, h, w, 3), (None, h, w, 3), (None, h, w, 3),
         (None, h, w, 3)))
    numFrames = 1

    # read real observations
    frameCntInt, loadData, realObservs = LossObservs.getRealObservables(
        ds, numFrames, w, h)
    icp = Icp(realObservs, camProp)

    # set up the scene
    scene = Scene(optModeEnum.MULTIFRAME_JOINT, frameCnt=1)
    objID = scene.addObject(objMesh,
                            paramInit,
                            segColor=np.array([1., 1., 1.]))
    scene.addCamera(f=camProp.f,
                    c=camProp.c,
                    near=camProp.near,
                    far=camProp.far,
                    frameSize=camProp.frameSize)
    finalMesh = scene.getFinalMesh()

    # render the scene
    renderer = DirtRenderer(finalMesh, renderModeEnum.SEG_DEPTH)
    virtObservs = renderer.render()

    # get loss over observables
    observLoss = LossObservs(virtObservs, realObservs,
                             renderModeEnum.SEG_DEPTH)
    segLoss, depthLoss, _ = observLoss.getL2Loss(isClipDepthLoss=True,
                                                 pyrLevel=2)

    # get constraints
    handConstrs = Constraints()
    paramList = scene.getParamsByItemID(
        [parTypeEnum.OBJ_ROT, parTypeEnum.OBJ_TRANS, parTypeEnum.OBJ_POSE_MAT],
        objID)
    rot = paramList[0]
    trans = paramList[1]
    poseMat = paramList[2]

    # get icp loss
    icpLoss = icp.getLoss(finalMesh.vUnClipped)
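    # icpLoss pulls the transformed (unclipped) mesh vertices toward the point
    # cloud back-projected from the observed depth map; see Icp for details.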

    # get final loss
    objImg = realObservs.col
    # zero-weighted term (presumably to keep the color observable wired into the graph)
    totalLoss1 = 1.0e0 * segLoss + 1e1 * depthLoss + 1e2 * icpLoss + 0.0 * tf.reduce_sum(objImg - virtObservs.seg)
    totalLoss2 = 1.15 * segLoss + 5.0 * depthLoss + 500.0 * icpLoss

    # get the variables for opt
    optVarsList = scene.getVarsByItemID(
        objID, [varTypeEnum.OBJ_ROT, varTypeEnum.OBJ_TRANS])

    # setup optimizer
    opti1 = Optimizer(totalLoss1,
                      optVarsList,
                      'Adam',
                      learning_rate=0.02 / 2.0)
    opti2 = Optimizer(totalLoss2, optVarsList, 'Adam', learning_rate=0.005)
    optiICP = Optimizer(1e1 * icpLoss, optVarsList, 'Adam', learning_rate=0.01)

    # get the optimization reset ops
    resetOpt1 = tf.variables_initializer(opti1.optimizer.variables())
    resetOpt2 = tf.variables_initializer(opti2.optimizer.variables())

    # tf stuff
    config = tf.ConfigProto()
    config.gpu_options.per_process_gpu_memory_fraction = 0.9
    config.gpu_options.allow_growth = True
    session = tf.Session(config=config)
    session.__enter__()
    tf.global_variables_initializer().run()

    # python renderer for rendering object texture
    pyRend = renderScene(h, w)
    modelPath = os.path.join(YCB_MODELS_DIR, configData['obj'])
    pyRend.addObjectFromMeshFile(modelPath, 'obj')
    pyRend.addCamera()
    pyRend.creatcamProjMat(camProp.f, camProp.c, camProp.near, camProp.far)

    # setup the plot window
    plt.ion()
    fig = plt.figure()
    ax1 = fig.add_subplot(2, 2, 1)
    lGT = ax1.imshow(np.zeros((240, 320, 3), dtype=np.float32))
    ax2 = fig.add_subplot(2, 2, 2)
    lRen = ax2.imshow(np.zeros((240, 320, 3), dtype=np.float32))
    ax3 = fig.add_subplot(2, 2, 3)
    lDep = ax3.imshow(np.random.uniform(0, 2, (240, 320)))
    ax4 = fig.add_subplot(2, 2, 4)
    lMask = ax4.imshow(np.random.uniform(0, 2, (240, 320, 3)))

    while True:
        session.run(resetOpt1)
        session.run(resetOpt2)

        # load new frame
        opti1.runOptimization(session, 1, {loadData: True})
        print(icpLoss.eval(feed_dict={loadData: False}))
        print(segLoss.eval(feed_dict={loadData: False}))
        print(depthLoss.eval(feed_dict={loadData: False}))

        # run the optimization for new frame
        frameID = (realObservs.frameID.eval(
            feed_dict={loadData: False}))[0].decode('UTF-8')
        # opti1.runOptimization(session, 200, {loadData: False})#, logLossFunc=True, lossPlotName=out_dir+'/LossFunc/'+frameID+'_1.png')
        opti2.runOptimization(session, 25, {loadData: False})

        pyRend.setObjectPose('obj',
                             poseMat.eval(feed_dict={loadData: False})[0].T)
        if USE_PYTHON_RENDERER:
            cRend, dRend = pyRend.render()

        plt.title(frameID)
        depRen = virtObservs.depth.eval(feed_dict={loadData: False})[0]
        depGT = realObservs.depth.eval(feed_dict={loadData: False})[0]
        segRen = virtObservs.seg.eval(feed_dict={loadData: False})[0]
        segGT = realObservs.seg.eval(feed_dict={loadData: False})[0]

        lGT.set_data(
            objImg.eval(feed_dict={loadData: False})[0])  # input image
        if USE_PYTHON_RENDERER:
            lRen.set_data(cRend)  # object rendered in the optimized pose
        lDep.set_data(np.abs(depRen - depGT)[:, :, 0])  # depth map error
        lMask.set_data(np.abs(segRen - segGT)[:, :, :])  # mask error
        plt.savefig(out_dir + '/' + frameID + '.png')
        plt.waitforbuttonpress(0.01)

        transNp = trans.eval(feed_dict={loadData: False})
        rotNp = rot.eval(feed_dict={loadData: False})
        savePickleData(out_dir + '/' + frameID + '.pkl', {
            'rot': rotNp,
            'trans': transNp
        })
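
# -----------------------------------------------------------------------------
# encodeDepthImg (used by the trackers above to save rendered depth maps) is
# defined elsewhere in the repo. As a hedged illustration only (not necessarily
# the repo's exact scheme), metric depth can be packed into two 8-bit channels
# so that it survives being written out as an image:
def _encodeDepthSketch(depth, maxDepth=2.0):
    d = np.clip(depth / maxDepth, 0., 1.) * (2 ** 16 - 1)  # map to 16-bit range
    enc = np.zeros(depth.shape + (3,), dtype=np.uint8)
    enc[..., 0] = (d // 256).astype(np.uint8)              # high byte
    enc[..., 1] = (d % 256).astype(np.uint8)               # low byte
    return enc
# -----------------------------------------------------------------------------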