# Benchmark script: time symmetry-aware rotation distances between two random
# quaternion sets (once with the HCP symmetries, once with the FCC symmetries)
# and plot a histogram of each result side by side.
from quats import Quat, rot_dist_w_syms, rand_quats
from symmetries import hcp_syms, fcc_syms
import matplotlib.pyplot as plt
from time import time

N = 3000          # size argument passed to rand_quats for each set
n_bins = 100      # number of histogram bins
use_torch = True  # forwarded to rand_quats; also triggers to_torch() below

if use_torch:
    # Convert the symmetry sets so they match the torch-backed quaternions.
    hcp_syms, fcc_syms = hcp_syms.to_torch(), fcc_syms.to_torch()

q1 = rand_quats(N, use_torch)
q2 = rand_quats(N, use_torch)

# Only the two distance computations are timed, not the random generation.
t1 = time()
dists_hcp = rot_dist_w_syms(q1, q2, hcp_syms)
dists_fcc = rot_dist_w_syms(q1, q2, fcc_syms)
print(f'{time()-t1:0.5f} seconds to compute {2*N} misorientations')
print(type(dists_fcc))

fig, axes = plt.subplots(1, 2)
axes[0].hist(dists_hcp, bins=n_bins)
axes[1].hist(dists_fcc, bins=n_bins)
axes[0].set_title('HCP')
axes[1].set_title('FCC')

# Without this call the figure is built and then silently discarded when the
# script is run non-interactively.
plt.show()
# Autograd check: build two random quaternion sets, compute symmetry-aware
# rotation distances to a reference set, and print the gradient of the summed
# squared HCP distance with respect to the estimated quaternions.
from quats import Quat, rot_dist_w_syms, rand_quats
from symmetries import hcp_syms, fcc_syms

N = 7

# Both sets are always requested with use_torch=True because the script calls
# .backward() on the result. The former module-level `use_torch = False` flag
# was never read and contradicted these calls, so it has been removed.
q_est = rand_quats(N, use_torch=True)
q_gt = rand_quats(N, use_torch=True)
q_est.X.requires_grad = True  # track gradients w.r.t. the estimate only
print(q_gt)

dists_hcp = rot_dist_w_syms(q_est, q_gt, hcp_syms.to_torch())
dists_fcc = rot_dist_w_syms(q_est, q_gt, fcc_syms.to_torch())

loss_hcp = (dists_hcp**2).sum()
# NOTE(review): loss_fcc is computed but never backpropagated or printed;
# only the HCP gradient is inspected below. Confirm whether that is intended.
loss_fcc = (dists_fcc**2).sum()

loss_hcp.backward()
print(q_est.X.grad)
# Visual check of rand_quats: divide the last three components of each
# quaternion by its first component and show the resulting 3-vectors as a
# 3D scatter plot.
from quats import Quat, rand_quats
import numpy as np
import matplotlib.pyplot as plt
# Kept: Matplotlib releases before 3.2 need this import to register the '3d' projection.
from mpl_toolkits.mplot3d import Axes3D

N = 100
q = rand_quats(N)

# Column 0 is kept two-dimensional (shape (N, 1)) so the division broadcasts
# across the three remaining columns.
first_component = q.X[:, :1]
R = q.X[:, 1:] / first_component

fig = plt.figure()
ax = fig.add_subplot(1, 1, 1, projection='3d')
ax.scatter(R[:, 0], R[:, 1], R[:, 2])
plt.show()
# Second FCC factor: the first two rows of the 4x4 identity matrix.
fcc_r2 = torch.eye(4)[:2]

# Third FCC factor: four rows whose first column is cos(k*pi/4) and whose last
# column is sin(k*pi/4) for k = 0..3; the two middle columns stay zero.
_fcc_r3_angles = pi / 4 * torch.arange(4)
fcc_r3 = torch.zeros((4, 4))
fcc_r3[:, 0] = torch.cos(_fcc_r3_angles)
fcc_r3[:, 3] = torch.sin(_fcc_r3_angles)

# Combine the three factors pairwise and flatten to one row of 4 per element.
# (fcc_r1 and outer_prod are defined earlier in the file, outside this view.)
fcc_r12 = outer_prod(fcc_r1, fcc_r2)
fcc_syms = outer_prod(fcc_r12, fcc_r3).reshape((-1, 4))

if __name__ == '__main__':
    from plotting_utils import *

    # Fixed seed so the demo below is reproducible.
    np.random.seed(1)
    q1 = rand_quats(())

    rhomb_wire = path2prism(rhomb_path)

    # Apply every HCP symmetry to the single random quaternion, then rotate
    # the wireframe and the axes by each resulting element.
    all_rots = outer_prod(q1, hcp_syms)
    all_wires = rotate(all_rots, rhomb_wire)
    all_axes = rotate(all_rots, rhomb_axes)

    # NOTE(review): this definition appears to continue past the end of the
    # visible source; only the visible statements are reproduced here.
    def setup_axes(m, n):
        r = np.sqrt(2)
        fig = plt.figure()
        axes = [
            fig.add_subplot(m, n, i + 1, projection='3d')
            for i in range(m * n)
        ]
        for a in axes:
            a.set_xlim(-r, r)
            a.set_ylim(-r, r)