def registration_cost(xyz0, xyz1, rot_reg=1e-3, n_iter=50):
    """Return the mean TPS regularisation cost of a bidirectional TPS-RPM fit.

    Runs ``registration.tps_rpm_bij`` between the two point clouds (passed
    through as given, with no rescaling). Returns the average of
    ``registration.tps_reg_cost`` over the two warps it produces.

    Args:
        xyz0, xyz1: point clouds to register (presumably (N, 3) arrays —
            TODO confirm against ``tps_rpm_bij``).
        rot_reg: rotation regularisation forwarded to ``tps_rpm_bij``
            (default 1e-3).
        n_iter: iteration count forwarded to ``tps_rpm_bij`` (default 50).

    Returns:
        ``(tps_reg_cost(f) + tps_reg_cost(g)) / 2.0`` for the two warps.

    NOTE(review): a second ``registration_cost(xyz0_p, xyz1_p)`` is defined
    immediately after this function and rebinds the name. This version is
    therefore not the one callers get at runtime — confirm which is intended.
    """
    f, g = registration.tps_rpm_bij(xyz0, xyz1, rot_reg=rot_reg, n_iter=n_iter)
    return (registration.tps_reg_cost(f) + registration.tps_reg_cost(g)) / 2.0
def registration_cost(xyz0_p, xyz1_p):
    """Register two point clouds with TPS-RPM; return the cost and both warps.

    Each argument is indexable:
    - element 0 holds the points passed to ``registration.tps_rpm_bij``;
    - element 1 holds the parameters passed to ``registration.unscale_tps``
      (presumably the scaling used to normalise the points — TODO confirm).

    Returns:
        ``(cost, f, g)``:
        - ``cost`` is the SUM (not the mean) of ``tps_reg_cost`` over both
          warps, computed before unscaling.
        - ``f`` and ``g`` are the two warps from ``tps_rpm_bij`` after
          ``unscale_tps``.
    """
    fwd, bwd = registration.tps_rpm_bij(
        xyz0_p[0], xyz1_p[0], rot_reg=tps_rot_reg, n_iter=tps_n_iter
    )
    # The regularisation cost is measured on the warps before they are unscaled.
    total = registration.tps_reg_cost(fwd) + registration.tps_reg_cost(bwd)

    # Each warp is unscaled with its own source/target parameter order.
    fwd_unscaled = registration.unscale_tps(fwd, xyz0_p[1], xyz1_p[1])
    bwd_unscaled = registration.unscale_tps(bwd, xyz1_p[1], xyz0_p[1])
    return total, fwd_unscaled, bwd_unscaled
def traj_cost(traj1, traj2, n, find_corr=False):
    """Return a symmetric TPS cost between two trajectories after resampling.

    Both trajectories are resampled with ``lerp`` to ``n`` evenly spaced
    samples, from their first index through their last. The two sample
    sets are then compared in one of two ways:

    * ``find_corr=True``: correspondences are estimated with
      ``registration.tps_rpm_bij`` (rot_reg=1e-3, n_iter=50).
    * otherwise: samples are paired index-wise and rotation-regularised
      TPS fits (``fit_ThinPlateSpline_RotReg``) are made in both directions.

    Args:
        traj1, traj2: arrays whose first axis indexes time steps (presumably
            (T, 3) positions — TODO confirm).
        n: number of resampled points per trajectory.
        find_corr: whether to estimate correspondences with TPS-RPM.

    Returns:
        The mean of ``registration.tps_reg_cost`` over the forward and
        backward warps.
    """
    def resample(traj):
        # Sample index positions 0 .. len-1 inclusive. The original grid ran
        # to len, one past the last time step given to lerp.
        num_steps = traj.shape[0]
        ts = np.linspace(0, num_steps - 1, n)
        return lerp(ts, range(num_steps), traj)

    xyz1 = resample(traj1)
    xyz2 = resample(traj2)

    if find_corr:
        # Call TPS-RPM directly. ``registration_cost`` is redefined in this
        # module with a (points, params)-tuple signature and a 3-tuple return,
        # so calling it with bare arrays would index single points.
        f, g = registration.tps_rpm_bij(xyz1, xyz2, rot_reg=1e-3, n_iter=50)
    else:
        bend_c = 0.05
        rot_c = [1e-3, 1e-3, 1e-3]
        scale_c = 0.1
        f = registration.fit_ThinPlateSpline_RotReg(xyz1, xyz2, bend_c, rot_c, scale_c)
        g = registration.fit_ThinPlateSpline_RotReg(xyz2, xyz1, bend_c, rot_c, scale_c)
    return (registration.tps_reg_cost(f) + registration.tps_reg_cost(g)) / 2.0
def traj_cost(traj1, traj2, f, g):
    """Return the symmetric thin-plate-spline cost between two trajectories.

    A TPS is fitted from ``traj1`` to ``traj2`` and another in the reverse
    direction. Both use the module-level ``traj_bend_c`` / ``traj_rot_c``.
    ``cost`` is the mean of their ``registration.tps_reg_cost``.

    When both ``f`` and ``g`` are given:
    - ``traj1`` is warped through ``f`` and ``traj2`` through ``g``;
    - each warped trajectory is fitted against the other, unwarped one, in
      both directions (four fits);
    - ``cost_fg`` is the mean of those four costs.

    Returns:
        ``cost`` if ``f`` or ``g`` is None, otherwise ``(cost, cost_fg)``.
    """
    def fit(src, dst):
        return registration.fit_ThinPlateSpline(src, dst, traj_bend_c, traj_rot_c)

    forward = fit(traj1, traj2)
    backward = fit(traj2, traj1)
    cost = (registration.tps_reg_cost(forward) + registration.tps_reg_cost(backward)) / 2

    # Without both warps there is nothing more to compare.
    if f is None or g is None:
        return cost

    warped1 = f.transform_points(traj1)
    warped2 = g.transform_points(traj2)
    fits = [
        fit(warped1, traj2),
        fit(traj2, warped1),
        fit(traj1, warped2),
        fit(warped2, traj1),
    ]
    c1, c2, c3, c4 = [registration.tps_reg_cost(t) for t in fits]
    cost_fg = (c1 + c2 + c3 + c4) / 4
    return cost, cost_fg