def fit_ThinPlateSpline_corr(x_nd, y_md, corr_nm, l, rot_reg, x_weights = None):
    """Fit a thin-plate spline from x_nd to correspondence-averaged targets in y_md.

    Each source point i gets the target
    ``sum_j corr_nm[i, j] * y_md[j] / sum_j corr_nm[i, j]`` and a fitting weight
    equal to its total correspondence ``sum_j corr_nm[i, j]``. That weight is
    optionally multiplied by ``x_weights``. Source points whose total
    correspondence is exactly zero are dropped before fitting, because their
    target would be a division by zero.

    Parameters
    ----------
    x_nd : array indexed as (point, coord); the source points.
    y_md : array whose rows are the target points (rows of corr_nm index them).
    corr_nm : 2-D array of correspondence weights; rows index x_nd, columns index y_md.
    l : bending coefficient, passed to fit_ThinPlateSpline as ``bend_coef``.
    rot_reg : rotation coefficient, passed to fit_ThinPlateSpline as ``rot_coef``.
    x_weights : optional per-point weights aligned with the rows of x_nd.
        If 1-D, they scale the per-point weights elementwise. If higher-dimensional,
        they are multiplied with the per-point weights broadcast as a column.

    Returns
    -------
    The spline returned by fit_ThinPlateSpline, annotated with ``_bend_coef``,
    ``_wt_n``, ``_rot_coef`` and ``_cost`` (tps_cost normalized by the mean weight).
    """
    wt_n = corr_nm.sum(axis=1)

    # Rows with no correspondence mass cannot be normalized, so drop them.
    inlier = wt_n != 0
    if not np.all(inlier):
        x_nd = x_nd[inlier, :]
        corr_nm = corr_nm[inlier, :]
        # wt_n is 1-D, so it is indexed with the mask alone.
        wt_n = wt_n[inlier]
        if x_weights is not None:
            x_weights = np.asarray(x_weights)[inlier]

    xtarg_nd = (corr_nm / wt_n[:, None]).dot(y_md)

    if x_weights is not None:
        x_weights = np.asarray(x_weights)
        if x_weights.ndim > 1:
            wt_n = wt_n[:, None] * x_weights
        else:
            wt_n = wt_n * x_weights

    f = fit_ThinPlateSpline(x_nd, xtarg_nd, bend_coef = l, wt_n = wt_n, rot_coef = rot_reg)
    f._bend_coef = l
    f._wt_n = wt_n
    f._rot_coef = rot_reg
    f._cost = tps.tps_cost(f.lin_ag, f.trans_g, f.w_ng, f.x_na, xtarg_nd, l, wt_n=wt_n)/wt_n.mean()
    return f
def rpm_em_step_stat(x_nd, y_md, l, T, rot_reg, prev_f, vis_cost_xy = None, outlierprior = 1e-2, normalize_iter = 20, T0 = .04, user_data=None):
    """Statistical interpretation of the RPM EM step.

    E-step: warp x_nd with prev_f and build an (n+1) x (m+1) matrix of
    unnormalized correspondence likelihoods. The matrix has Gaussian terms of
    variance T between warped sources and targets, plus an extra outlier
    row/column built with variance T0. It is then balanced with
    sinkhorn_balance_coeffs.

    M-step: fit a thin-plate spline mapping each inlier source point to its
    correspondence-weighted average target.

    Parameters
    ----------
    x_nd : (n, d) array of source points.
    y_md : (m, d) array of target points.
    l : bending coefficient, passed to fit_ThinPlateSpline as ``bend_coef``.
    T : temperature (Gaussian variance) for source/target distances.
    rot_reg : rotation coefficient, passed to fit_ThinPlateSpline as ``rot_coef``.
    prev_f : current transform; must provide ``transform_points``.
    vis_cost_xy : optional cost array, presumably (n, m) (it is multiplied
        elementwise into the n x m likelihoods — confirm with callers).
        ``exp(-vis_cost_xy)``, column-normalized, replaces the uniform 1/n prior.
    outlierprior : prior probability mass given to the outlier row.
    normalize_iter : iteration count passed to sinkhorn_balance_coeffs.
    T0 : temperature used for the outlier row/column.
    user_data : not used by this function.

    Returns
    -------
    (f, corr_nm), where f is the fitted spline (annotated with ``_bend_coef``,
    ``_rot_coef`` and ``_cost``) and corr_nm is the balanced n x m
    correspondence block. corr_nm includes rows later discarded as outliers.
    """
    n, d = x_nd.shape
    m, _ = y_md.shape
    xwarped_nd = prev_f.transform_points(x_nd)

    dist_nm = ssd.cdist(xwarped_nd, y_md, 'sqeuclidean')
    outlier_dist_1m = ssd.cdist(xwarped_nd.mean(axis=0)[None, :], y_md, 'sqeuclidean')
    outlier_dist_n1 = ssd.cdist(xwarped_nd, y_md.mean(axis=0)[None, :], 'sqeuclidean')

    # Note: proportionality constants within a column can be ignored, since
    # Sinkhorn balancing normalizes the columns first. Every entry of column j
    # is therefore multiplied by exp(outlier_dist_1m[j] / (2*T0)). That is, it is
    # divided by the column's unnormalized outlier likelihood
    # exp(-outlier_dist_1m[j] / (2*T0)), which keeps the probabilities from
    # collapsing to zero.
    prob_nm = np.exp(-(dist_nm / (2 * T)) + (outlier_dist_1m / (2 * T0))) / np.sqrt(T)
    if vis_cost_xy is not None:
        pi = np.exp(-vis_cost_xy)
        # Normalize along columns; these are proper probabilities over the source points.
        pi /= pi.sum(axis=0)[None, :]
        prob_nm *= (1. - outlierprior) * pi
    else:
        prob_nm *= (1. - outlierprior) / float(n)

    # The column rescaling above turns the outlier row's exp(-outlier_dist_1m / (2*T0)) factor into 1.
    outlier_prob_1m = outlierprior * np.ones((1, m)) / np.sqrt(T0)
    outlier_prob_n1 = np.exp(-outlier_dist_n1 / (2 * T0)) / np.sqrt(T0)

    prob_NM = np.empty((n + 1, m + 1), 'f4')
    prob_NM[:n, :m] = prob_nm
    prob_NM[:n, m] = outlier_prob_n1[:, 0]
    prob_NM[n, :m] = outlier_prob_1m[0, :]
    prob_NM[n, m] = 0

    r_N, c_M = sinkhorn_balance_coeffs(prob_NM, normalize_iter)
    prob_NM *= r_N[:, None]
    prob_NM *= c_M[None, :]
    # prob_NM needs to be row-normalized at this point
    corr_nm = prob_NM[:n, :m]

    wt_n = corr_nm.sum(axis=1)
    # Discard points that are outliers (i.e. their total correspondence is smaller than 1e-2).
    inlier = wt_n > 1e-2
    if np.any(~inlier):
        x_nd = x_nd[inlier, :]
        # wt_n is 1-D, so it is indexed with the mask alone.
        wt_n = wt_n[inlier]
        xtarg_nd = (corr_nm[inlier, :] / wt_n[:, None]).dot(y_md)
    else:
        xtarg_nd = (corr_nm / wt_n[:, None]).dot(y_md)

    f = fit_ThinPlateSpline(x_nd, xtarg_nd, bend_coef = l, wt_n = wt_n, rot_coef = rot_reg)
    f._bend_coef = l
    f._rot_coef = rot_reg
    f._cost = tps.tps_cost(f.lin_ag, f.trans_g, f.w_ng, f.x_na, xtarg_nd, l, wt_n=wt_n)/wt_n.mean()
    return f, corr_nm