Exemplo n.º 1
0
    def __init__(self, demos, actionfile=None, 
                 n_iter=settings.N_ITER, em_iter=settings.EM_ITER, 
                 reg_init=settings.REG[0], reg_final=settings.REG[1], 
                 rad_init=settings.RAD[0], rad_final=settings.RAD[1], 
                 rot_reg=settings.ROT_REG, 
                 outlierprior=settings.OUTLIER_PRIOR, outlierfrac=settings.OURLIER_FRAC, 
                 prior_fn=None, 
                 f_solver_factory=solver.AutoTpsSolverFactory(), 
                 g_solver_factory=solver.AutoTpsSolverFactory(use_cache=False)):
        """Create a GPU-backed batch TPS-RPM bijective registration factory.

        Every registration parameter is forwarded unchanged, by keyword, to
        the parent class ``__init__``.

        :param demos: demonstrations; passed straight through to the parent.
        :param actionfile: optional file path (presumably HDF5, since it is
            loaded with ``read_h5`` -- confirm). When truthy, a ``GPUContext``
            is built and populated from it and stored as ``self.src_ctx``.
        :param n_iter, em_iter, reg_init, reg_final, rad_init, rad_final,
            rot_reg, outlierprior, outlierfrac, prior_fn, f_solver_factory,
            g_solver_factory: forwarded unchanged to the parent class.
        :raises NotImplementedError: when ``lfd.registration._has_cuda`` is
            falsy.

        NOTE(review): the default solver factories are created once, when
        the function is defined, so every instance relying on the defaults
        shares the same two factory objects.
        NOTE(review): ``self.bend_coefs`` and ``self.src_ctx`` are only
        assigned when ``actionfile`` is truthy; callers that read them must
        not assume they always exist.
        """
        # Fail fast: this factory cannot run without the CUDA backend.
        if not lfd.registration._has_cuda:
            raise NotImplementedError("CUDA not installed")

        parent = super(BatchGpuTpsRpmBijRegistrationFactory, self)
        parent.__init__(
            demos=demos,
            n_iter=n_iter,
            em_iter=em_iter,
            reg_init=reg_init,
            reg_final=reg_final,
            rad_init=rad_init,
            rad_final=rad_final,
            rot_reg=rot_reg,
            outlierprior=outlierprior,
            outlierfrac=outlierfrac,
            prior_fn=prior_fn,
            f_solver_factory=f_solver_factory,
            g_solver_factory=g_solver_factory,
        )

        self.actionfile = actionfile
        if self.actionfile:
            # Regularization schedule derived from reg_init/reg_final/n_iter
            # (presumably log-spaced, per the helper's name -- confirm); the
            # parent __init__ is expected to have set these attributes.
            coefs = self.bend_coefs = tps.loglinspace(self.reg_init, self.reg_final, self.n_iter)
            ctx = self.src_ctx = GPUContext(coefs)
            ctx.read_h5(actionfile)
        # NOTE(review): consumer of this flag is not visible here; presumably
        # it enables a warning about clipped point clouds -- confirm.
        self.warn_clip_cloud = True
Exemplo n.º 2
0
 def set_landmark_file(self, landmarkf):
     """Load landmark data from ``landmarkf`` and reset the weight vector.

     Creates ``self.landmark_ctx`` (a ``GPUContext`` built with no
     arguments and populated via ``read_h5``) and ``self.landmark_targ_ctx``
     (a ``TgtContext`` wrapping it). It then replaces ``self.weights`` with a
     zero vector of length ``src_ctx.N + landmark_ctx.N + MulFeats.N_costs``.

     :param landmarkf: path to the landmark file (presumably HDF5, given
         ``read_h5`` -- confirm).

     NOTE(review): requires ``self.src_ctx`` to exist already; otherwise the
     size computation raises AttributeError. Any previously assigned
     ``self.weights`` are discarded.
     """
     # Bind the attribute first so it is set even if read_h5 fails.
     landmark_ctx = self.landmark_ctx = GPUContext()
     landmark_ctx.read_h5(landmarkf)
     self.landmark_targ_ctx = TgtContext(landmark_ctx)

     n_weights = self.src_ctx.N + landmark_ctx.N + MulFeats.N_costs
     self.weights = np.zeros(n_weights)