def test_perfect_rigid():
    """Register the reference points against themselves with driftRigid.

    Takes the final item of (at most) the first 100 values produced by
    ``driftRigid`` and checks that the resulting transform is the 2-D
    identity: ``R`` equal to ``eye(2)``, ``t`` equal to zeros and ``s``
    equal to 1 (to ``assert_almost_equal`` precision).
    """
    points = loadPoints(Path("tests/fixtures/ref.txt"))

    # Both arguments are the same point set, so no motion should be needed.
    iterations = islice(driftRigid(points, points), 100)
    _, xform = last(iterations)

    identity_rotation = np.eye(2)
    np.testing.assert_almost_equal(xform.R, identity_rotation)
    zero_translation = np.zeros(2)
    np.testing.assert_almost_equal(xform.t, zero_translation)
    np.testing.assert_almost_equal(xform.s, 1)
def test_affine():
    """Compare driftAffine's result on the fixtures with a stored transform.

    Loads the reference and degraded fixture point sets, takes the final
    item of (at most) the first 100 values produced by ``driftAffine`` with
    ``w=0.5``, and checks its ``B`` and ``t`` against the transform stored
    in ``affine.pickle``.
    """
    fixtures = Path("tests/fixtures")
    reference = loadPoints(Path("tests/fixtures/ref.txt"))
    degraded = loadPoints(Path("tests/fixtures/deg.txt"))
    expected = loadXform(Path("tests/fixtures/affine.pickle"))

    iterations = islice(driftAffine(reference, degraded, w=0.5), 100)
    _, xform = last(iterations)

    np.testing.assert_almost_equal(xform.B, expected.B)
    np.testing.assert_almost_equal(xform.t, expected.t)
def test_cpd_prior():
    """Check driftRigid with a matrix ``w`` against the ``align`` result.

    Builds a random 3-D rigid transform from a seeded RNG, applies it to ten
    random points, then runs ``driftRigid`` with ``w`` set to an identity
    matrix of size ``len(X)`` (presumably a one-to-one correspondence prior
    -- confirm against driftRigid). The final transform must agree with the
    one returned by ``align`` in ``R``, ``t`` and ``s``.
    """
    from coherent_point_drift.align import driftRigid

    rng = np.random.RandomState(4)
    ndim = 3

    # NOTE: the order of the draws below fixes the seeded values; keep it.
    rotation = rotationMatrix(*next(randomRotations(ndim, rng)))
    translation = rng.normal(size=ndim)
    scale = rng.normal(size=1)[0]
    source = rng.normal(size=(10, ndim))
    target = RigidXform(rotation, translation, scale) @ source

    prior = np.eye(len(source))
    iterations = islice(driftRigid(source, target, w=prior), 200)
    _, cpd = last(iterations)
    ls = align(source, target)

    np.testing.assert_almost_equal(cpd.R, ls.R)
    np.testing.assert_almost_equal(cpd.t, ls.t)
    np.testing.assert_almost_equal(cpd.s, ls.s)
def plot(args):
    """Plot fitting results read as a pickle stream from standard input.

    The stream is expected to hold the reference point set first, followed
    by records of the form ``(degradation, (fit, rmsd))`` as yielded by
    ``loadAll``. Three figures are produced:

    * figure 0 -- the ``rmsd`` sequence of every record;
    * figure 1 -- the degraded points ('o'), the fitted points ('+') in a
      matching random colour, and the reference points ('v', black);
    * figure 2 -- the minimum RMSD of each record against its rotation in
      degrees, sorted by rotation.

    Parameters
    ----------
    args
        Unused by this function; presumably the parsed command-line
        namespace passed to every sub-command -- confirm against the caller.
    """
    from pickle import load
    from sys import stdin
    import matplotlib.pyplot as plt
    from numpy.random import seed, random
    from coherent_point_drift.geometry import rigidXform
    from coherent_point_drift.util import last
    from math import degrees

    # Fixed seed so the random colours are the same on every run.
    seed(4)

    # SECURITY: unpickling executes arbitrary code embedded in the stream.
    # Only pipe data from a trusted producer into this command.
    reference = load(stdin.buffer)

    rmsds = []
    rotations = []
    for degradation, (fit, rmsd) in loadAll(stdin.buffer):
        rmsds.append(rmsd)
        # NOTE(review): degradation[0][0] is treated as an angle in radians
        # (it is passed to math.degrees) -- confirm against the producer.
        rotations.append(degrees(degradation[0][0]))

        plt.figure(0)
        plt.plot(rmsd, alpha=0.3)

        plt.figure(1)
        color = random(3)
        degraded = degrade(reference, *degradation)
        plt.scatter(degraded[:, 0], degraded[:, 1],
                    marker='o', color=color, alpha=0.2)
        fitted = rigidXform(degraded, *last(fit))
        plt.scatter(fitted[:, 0], fitted[:, 1],
                    marker='+', color=color, alpha=0.4)

    # Select the point-cloud figure explicitly rather than relying on it
    # being the current figure after the loop.
    plt.figure(1)
    plt.scatter(reference[:, 0], reference[:, 1], marker='v', color='black')

    min_rmsds = [min(trace) for trace in rmsds]
    rotation_rmsds = sorted(zip(rotations, min_rmsds), key=lambda x: x[0])
    plt.figure(2)
    plt.plot(*zip(*rotation_rmsds))
    plt.show()