def fit_tps(x_gt, x, plot=True, bend_coeff=0.1):
    """
    Fit a thin-plate spline (TPS) mapping source points `x` onto target points `x_gt`.

    The returned transform can be used to correct the state-dependent hydra errors.

    Parameters
    ----------
    x_gt : array-like, shape (N, 3)
        Ground-truth target points.
    x : array-like, shape (N, 3)
        Source points (presumably in one-to-one correspondence with `x_gt`
        row by row -- confirm against registration.fit_ThinPlateSpline).
    plot : bool, default True
        If True, build warping visualisation requests with `plot_warping` and
        send them to a newly created `PlotterInit()`.
    bend_coeff : float, default 0.1
        TPS bending regularisation. Increase it to make the interpolation
        smoother; decrease it to fit the given points more closely. The
        rotation coefficient is derived from it as 0.01 * bend_coeff.

    Returns
    -------
    f_tps
        The object returned by `registration.fit_ThinPlateSpline`; it exposes
        `transform_points` (used below for plotting).
    """
    # Rotation regularisation is kept proportional to (and much weaker than)
    # the bending regularisation.
    f_tps = registration.fit_ThinPlateSpline(
        x, x_gt, bend_coef=bend_coeff, rot_coef=0.01 * bend_coeff
    )

    if plot:
        plot_reqs = plot_warping(
            f_tps.transform_points, x, x_gt, fine=False, draw_plinks=True
        )
        plotter = PlotterInit()
        for req in plot_reqs:
            plotter.request(req)

    return f_tps


def traj_cost(traj1, traj2, f, g, bend_c=None, rot_c=None):
    """
    Symmetric thin-plate-spline registration cost between two trajectories.

    A TPS is fitted from `traj1` to `traj2` and another from `traj2` to
    `traj1`. The first result `cost` is the mean of their
    `registration.tps_reg_cost` values.

    If both `f` and `g` are given, a second cost is also computed. `f` and `g`
    are objects exposing `transform_points`, presumably TPS transforms such as
    those returned by `fit_tps`. First `traj1` is warped by `f` and `traj2` is
    warped by `g`. Then `cost_fg` is the mean reg cost of four fits:
    f(traj1) -> traj2, traj2 -> f(traj1), traj1 -> g(traj2), g(traj2) -> traj1.

    Parameters
    ----------
    traj1, traj2 :
        Point sets accepted by `registration.fit_ThinPlateSpline`.
    f, g :
        Optional transforms (may be None).
    bend_c, rot_c : optional
        Bending / rotation coefficients forwarded to
        `registration.fit_ThinPlateSpline`. When None, the module-level
        `traj_bend_c` / `traj_rot_c` are used.

    Returns
    -------
    cost
        When `f` or `g` is None. If only one of them is given, it is ignored.
    (cost, cost_fg)
        When both `f` and `g` are given.
    """
    # Resolved at call time, so later changes to the module globals still apply.
    if bend_c is None:
        bend_c = traj_bend_c
    if rot_c is None:
        rot_c = traj_rot_c

    def _reg_cost(src, tgt):
        # Regularisation cost of the TPS that warps `src` onto `tgt`.
        tps = registration.fit_ThinPlateSpline(src, tgt, bend_c, rot_c)
        return registration.tps_reg_cost(tps)

    cost = (_reg_cost(traj1, traj2) + _reg_cost(traj2, traj1)) / 2

    if f is None or g is None:
        return cost

    traj1_f = f.transform_points(traj1)
    traj2_g = g.transform_points(traj2)
    # Summed left to right in the same order as before, then averaged.
    cost_fg = (
        _reg_cost(traj1_f, traj2)
        + _reg_cost(traj2, traj1_f)
        + _reg_cost(traj1, traj2_g)
        + _reg_cost(traj2_g, traj1)
    ) / 4
    return cost, cost_fg