예제 #1
0
def fitting_methods_equivalent_without_rot_reg():
    """Check that tps.tps_fit and tps.tps_fit2 agree when rotation regularization is off.

    Both fitters are called on the same random 100x3 point sets with
    coefficients (.01, 0). In tps_fit2 these are the bend and rotation
    coefficients, judging by how tps_regrot_with_quad_cost calls it; the
    same ordering for tps_fit is assumed (TODO confirm). The three returned
    arrays (lin_ag, trans_g, w_ng) must match elementwise under
    np.allclose.
    """
    # Use a local, fixed-seed generator so a failure can be reproduced
    # and the global np.random state is left untouched.
    rng = np.random.RandomState(0)
    pts0 = rng.randn(100, 3)
    pts1 = rng.randn(100, 3)
    lin_ag, trans_g, w_ng = tps.tps_fit(pts0, pts1, .01, 0)
    lin2_ag, trans2_g, w2_ng = tps.tps_fit2(pts0, pts1, .01, 0)
    assert np.allclose(lin_ag, lin2_ag)
    assert np.allclose(trans_g, trans2_g)
    assert np.allclose(w_ng, w2_ng)
예제 #2
0
def tps_regrot_with_quad_cost():
    """Check tps.tps_fit_regrot with a quadratic rotation cost against tps.tps_fit2.

    tps_fit_regrot is given the penalty rfunc(b) = rot_coef * sum((b - I)**2)
    and run for max_iter=20. tps_fit2 is given the same bend_coef and
    rot_coef directly. The two results (linear part, translation, warp
    weights) must agree within an absolute tolerance of 1e-2. The looser
    tolerance presumably allows for the iterative solver's inexactness;
    confirm against tps_fit_regrot.
    """
    # Fixed-seed local generator: whether the 1e-2 tolerance is met can
    # depend on the data, so a failing case must be reproducible.
    rng = np.random.RandomState(0)
    x_na = rng.randn(100, 3)
    y_ng = rng.randn(100, 3)
    bend_coef = .1
    rot_coef = 19
    atol = 1e-2

    def rfunc(b):
        # Squared Frobenius distance of b from the identity, scaled by rot_coef.
        return rot_coef * ((b - np.eye(3)) ** 2).sum()

    correct_lin_ag, correct_trans_g, correct_w_ng = tps.tps_fit2(
        x_na, y_ng, bend_coef, rot_coef)
    lin_ag, trans_g, w_ng = tps.tps_fit_regrot(
        x_na, y_ng, bend_coef, rfunc, max_iter=20)
    assert np.allclose(correct_trans_g, trans_g, atol=atol)
    assert np.allclose(correct_lin_ag, lin_ag, atol=atol)
    assert np.allclose(correct_w_ng, w_ng, atol=atol)