예제 #1
0
 def test_tpsrpm_objective_monotonicity(self):
     """Record the TPS-RPM objective after every EM step and report its monotonicity.

     The callback stores, for outer iteration ``i`` and EM step ``i_em``, the
     summed value of ``TpsRpmRegistration.get_objective2``. The printed
     boolean matrix ``np.diff(objs, axis=1) <= 0`` is True wherever the
     objective did not increase between consecutive EM steps of the same
     outer iteration. It is printed rather than asserted for now (see TODO).
     """
     n_iter = 10
     em_iter = 10
     reg_factory = TpsRpmRegistrationFactory(n_iter=n_iter, em_iter=em_iter, f_solver_factory=solver.AutoTpsSolverFactory(use_cache=False))

     # objs[i, i_em] = objective after EM step i_em of outer iteration i;
     # any entry the callback never writes stays 0.
     objs = np.zeros((n_iter, em_iter))
     def callback(i, i_em, x_nd, y_md, xtarg_nd, wt_n, f, corr_nm, rad):
         objs[i, i_em] = TpsRpmRegistration.get_objective2(x_nd, y_md, f, corr_nm, rad).sum()

     # values()[0] only works on Python 2 lists; next(iter(...)) also works on
     # Python 3 dict views and picks the same first element.
     demo = next(iter(self.demos.values()))
     # Only the callback side effect is needed; the returned registration is unused.
     reg_factory.register(demo, self.test_scene_state, callback=callback)
     # Single-argument print(...) prints the same thing under Python 2 and 3.
     print(np.diff(objs, axis=1) <= 0) # TODO assert when monotonicity is more robust
예제 #2
0
    def test_tpsrpm_objective_monotonicity(self):
        """Record the TPS-RPM objective after every EM step and report its monotonicity.

        The callback stores, for outer iteration ``i`` and EM step ``i_em``,
        the summed value of ``TpsRpmRegistration.get_objective2``. The printed
        boolean matrix ``np.diff(objs, axis=1) <= 0`` is True wherever the
        objective did not increase between consecutive EM steps of the same
        outer iteration. It is printed rather than asserted for now (see TODO).
        """
        n_iter = 10
        em_iter = 10
        reg_factory = TpsRpmRegistrationFactory(
            n_iter=n_iter,
            em_iter=em_iter,
            f_solver_factory=solver.AutoTpsSolverFactory(use_cache=False))

        # objs[i, i_em] = objective after EM step i_em of outer iteration i;
        # any entry the callback never writes stays 0.
        objs = np.zeros((n_iter, em_iter))

        def callback(i, i_em, x_nd, y_md, xtarg_nd, wt_n, f, corr_nm, rad):
            objs[i, i_em] = TpsRpmRegistration.get_objective2(
                x_nd, y_md, f, corr_nm, rad).sum()

        # values()[0] only works on Python 2 lists; next(iter(...)) also works
        # on Python 3 dict views and picks the same first element.
        demo = next(iter(self.demos.values()))
        # Only the callback side effect is needed; the returned registration
        # is unused.
        reg_factory.register(demo, self.test_scene_state, callback=callback)
        # Single-argument print(...) prints the same thing under Python 2 and 3.
        print(np.diff(objs, axis=1) <= 0)  # TODO assert when monotonicity is more robust
예제 #3
0
    def test_tps_objective(self):
        """Check that the fitted TPS coefficients are feasible and reproduce its objective.

        Rebuilds the quadratic form used by ``tps_fit3`` from the fitted
        function's inputs (``x_na``, ``y_ng``, ``wt_n``, ``rot_coef``,
        ``bend_coef``). It then asserts that:

        1. the stacked coefficients ``theta = [trans; lin; w]`` satisfy the
           equality constraint ``A.dot(theta) == 0``, and
        2. evaluating the quadratic form at ``theta`` matches
           ``reg.f.get_objective().sum()``.
        """
        reg_factory = TpsRpmRegistrationFactory(
            {}, f_solver_factory=solver.CpuTpsSolverFactory(use_cache=False))
        # values()[0] only works on Python 2 lists; next(iter(...)) also works
        # on Python 3 dict views and picks the same first element.
        demo = next(iter(self.demos.values()))
        reg = reg_factory.register(demo, self.test_scene_state)

        x_na = reg.f.x_na
        y_ng = reg.f.y_ng
        wt_n = reg.f.wt_n
        rot_coef = reg.f.rot_coef
        bend_coef = reg.f.bend_coef

        # code from tps_fit3
        n, d = x_na.shape

        K_nn = tps.tps_kernel_matrix(x_na)
        # Q.dot(theta) == trans + x_na.dot(lin) + K_nn.dot(w) for theta stacked
        # row-wise as [trans (1 row); lin (d rows); w (n rows)].
        Q = np.c_[np.ones((n, 1)), x_na, K_nn]
        rot_coefs = np.ones(d) * rot_coef if np.isscalar(
            rot_coef) else np.asarray(rot_coef)
        # Shape (d+1, n+d+1); its first d+1 columns are zero, so the constraint
        # only involves the w rows of theta.
        A = np.r_[np.zeros((d + 1, d + 1)), np.c_[np.ones((n, 1)), x_na]].T

        WQ = wt_n[:, None] * Q
        # Quadratic term: weighted data term, plus bending penalty on the w
        # block and rotation penalty on the lin block.
        H = Q.T.dot(WQ)
        H[d + 1:, d + 1:] += bend_coef * K_nn
        H[1:d + 1, 1:d + 1] += np.diag(rot_coefs)

        # Linear term. Together with the H rotation block and the
        # rot_coefs.sum() constant below, the diag(rot_coefs) part penalizes
        # lin's deviation from the identity. This assumes y_ng has at least
        # d columns.
        f = -WQ.T.dot(y_ng)
        f[1:d + 1, 0:d] -= np.diag(rot_coefs)

        # optimum point
        theta = np.r_[reg.f.trans_g[None, :], reg.f.lin_ag, reg.f.w_ng]

        # equality constraint. The zero target takes its shape from the result
        # rather than a hard-coded (4, 3), so the check also holds when d != 3.
        constraint = A.dot(theta)
        self.assertTrue(np.allclose(constraint, np.zeros_like(constraint)))
        # objective
        obj = np.trace(theta.T.dot(H.dot(theta))) + 2*np.trace(f.T.dot(theta)) \
        + np.trace(y_ng.T.dot(wt_n[:,None]*y_ng)) + rot_coefs.sum() # constant
        self.assertTrue(np.allclose(obj, reg.f.get_objective().sum()))
예제 #4
0
 def test_tps_objective(self):
     """Check that the fitted TPS coefficients are feasible and reproduce its objective.

     Rebuilds the quadratic form used by ``tps_fit3`` from the fitted
     function's inputs (``x_na``, ``y_ng``, ``wt_n``, ``rot_coef``,
     ``bend_coef``). It then asserts that:

     1. the stacked coefficients ``theta = [trans; lin; w]`` satisfy the
        equality constraint ``A.dot(theta) == 0``, and
     2. evaluating the quadratic form at ``theta`` matches
        ``reg.f.get_objective().sum()``.
     """
     reg_factory = TpsRpmRegistrationFactory({}, f_solver_factory=solver.CpuTpsSolverFactory(use_cache=False))
     # values()[0] only works on Python 2 lists; next(iter(...)) also works on
     # Python 3 dict views and picks the same first element.
     demo = next(iter(self.demos.values()))
     reg = reg_factory.register(demo, self.test_scene_state)

     x_na = reg.f.x_na
     y_ng = reg.f.y_ng
     wt_n = reg.f.wt_n
     rot_coef = reg.f.rot_coef
     bend_coef = reg.f.bend_coef

     # code from tps_fit3
     n,d = x_na.shape

     K_nn = tps.tps_kernel_matrix(x_na)
     # Q.dot(theta) == trans + x_na.dot(lin) + K_nn.dot(w) for theta stacked
     # row-wise as [trans (1 row); lin (d rows); w (n rows)].
     Q = np.c_[np.ones((n,1)), x_na, K_nn]
     rot_coefs = np.ones(d) * rot_coef if np.isscalar(rot_coef) else np.asarray(rot_coef)
     # Shape (d+1, n+d+1); its first d+1 columns are zero, so the constraint
     # only involves the w rows of theta.
     A = np.r_[np.zeros((d+1,d+1)), np.c_[np.ones((n,1)), x_na]].T

     WQ = wt_n[:,None] * Q
     # Quadratic term: weighted data term, plus bending penalty on the w block
     # and rotation penalty on the lin block.
     H = Q.T.dot(WQ)
     H[d+1:,d+1:] += bend_coef * K_nn
     H[1:d+1, 1:d+1] += np.diag(rot_coefs)

     # Linear term. Together with the H rotation block and the rot_coefs.sum()
     # constant below, the diag(rot_coefs) part penalizes lin's deviation from
     # the identity. This assumes y_ng has at least d columns.
     f = -WQ.T.dot(y_ng)
     f[1:d+1,0:d] -= np.diag(rot_coefs)

     # optimum point
     theta = np.r_[reg.f.trans_g[None,:], reg.f.lin_ag, reg.f.w_ng]

     # equality constraint. The zero target takes its shape from the result
     # rather than a hard-coded (4, 3), so the check also holds when d != 3.
     constraint = A.dot(theta)
     self.assertTrue(np.allclose(constraint, np.zeros_like(constraint)))
     # objective
     obj = np.trace(theta.T.dot(H.dot(theta))) + 2*np.trace(f.T.dot(theta)) \
     + np.trace(y_ng.T.dot(wt_n[:,None]*y_ng)) + rot_coefs.sum() # constant
     self.assertTrue(np.allclose(obj, reg.f.get_objective().sum()))