from time import perf_counter
from time import time

import matplotlib.pyplot as plt

from quats import Quat, rot_dist_w_syms, rand_quats
from symmetries import hcp_syms, fcc_syms

# Time rot_dist_w_syms on N random quaternion pairs for both the HCP and FCC
# symmetry sets, then plot a histogram of each resulting distance distribution.
N = 3000
n_bins = 100

# When True, the symmetry sets and the random quaternions use the torch backend.
use_torch = True

if use_torch:
    hcp_syms, fcc_syms = hcp_syms.to_torch(), fcc_syms.to_torch()

q1 = rand_quats(N, use_torch)
q2 = rand_quats(N, use_torch)

# perf_counter is monotonic and high-resolution, so it is the right clock for
# elapsed-time measurement (time() follows the wall clock and can jump).
t1 = perf_counter()
dists_hcp = rot_dist_w_syms(q1, q2, hcp_syms)
dists_fcc = rot_dist_w_syms(q1, q2, fcc_syms)
print(f'{perf_counter()-t1:0.5f} seconds to compute {2*N} misorientations')

fig, axes = plt.subplots(1, 2)

axes[0].hist(dists_hcp, bins=n_bins)
axes[1].hist(dists_fcc, bins=n_bins)

axes[0].set_title('HCP')
axes[1].set_title('FCC')

# Without an explicit show() the figure is never displayed when this file is
# run as a plain (non-interactive) script.
plt.show()
# Example no. 2
from quats import Quat, rot_dist_w_syms, rand_quats
from symmetries import hcp_syms, fcc_syms

# Autograd check: build a loss from symmetry-aware rotation distances between
# estimated and ground-truth quaternions, backpropagate it, and print the
# gradient that lands on the estimate's underlying array.
N = 7

# requires_grad / backward() below only work with torch-backed quaternions, so
# this flag must be True; it is passed to rand_quats rather than ignored.
use_torch = True

q_est = rand_quats(N, use_torch=use_torch)
q_gt = rand_quats(N, use_torch=use_torch)

# Mark the estimate's array as a leaf that accumulates gradients.
q_est.X.requires_grad = True

print(q_gt)

dists_hcp = rot_dist_w_syms(q_est, q_gt, hcp_syms.to_torch())
dists_fcc = rot_dist_w_syms(q_est, q_gt, fcc_syms.to_torch())

loss_hcp = (dists_hcp**2).sum()
# NOTE(review): loss_fcc is computed but never backpropagated; only the HCP
# gradient is printed below.
loss_fcc = (dists_fcc**2).sum()

loss_hcp.backward()

print(q_est.X.grad)
from quats import Quat, rand_quats

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# Scatter-plot N random quaternions after projecting each one into 3-D.
N = 100

q = rand_quats(N)

# Divide the last three components of every quaternion row by its first one;
# X[:, :1] keeps a trailing axis so the division broadcasts row-wise.
# NOTE(review): if column 0 is the scalar part w, this is the Rodrigues
# (tan-half-angle) vector — confirm the component order used by Quat.
R = q.X[:, 1:] / q.X[:, :1]

fig = plt.figure()
ax = fig.add_subplot(1, 1, 1, projection='3d')

# One point per quaternion, coordinates taken from the three columns of R.
ax.scatter(R[:, 0], R[:, 1], R[:, 2])

plt.show()
# Example no. 4
# FCC symmetry set: three generator sets combined with outer_prod.
# fcc_r2 is the first two rows of the 4x4 identity: (1, 0, 0, 0) and (0, 1, 0, 0).
fcc_r2 = torch.eye(4)[:2]

# fcc_r3 holds four rows with column 0 = cos(k*pi/4), column 3 = sin(k*pi/4)
# for k = 0..3, and zeros in the middle two columns.
# NOTE(review): assuming (w, x, y, z) ordering these are k*90-degree turns
# about z — confirm against the quaternion convention used by outer_prod.
_fcc_r3_half_angles = pi / 4 * torch.arange(4)
fcc_r3 = torch.zeros((4, 4))
fcc_r3[:, 0] = torch.cos(_fcc_r3_half_angles)
fcc_r3[:, 3] = torch.sin(_fcc_r3_half_angles)

fcc_r12 = outer_prod(fcc_r1, fcc_r2)
# Flatten the combined products to one 4-component row per symmetry element.
fcc_syms = outer_prod(fcc_r12, fcc_r3).reshape((-1, 4))

if __name__ == '__main__':

    from plotting_utils import *

    np.random.seed(1)
    q1 = rand_quats(())

    rhomb_wire = path2prism(rhomb_path)
    all_rots = outer_prod(q1, hcp_syms)
    all_wires = rotate(all_rots, rhomb_wire)
    all_axes = rotate(all_rots, rhomb_axes)

    def setup_axes(m, n):
        r = np.sqrt(2)
        fig = plt.figure()
        axes = [
            fig.add_subplot(m, n, i + 1, projection='3d') for i in range(m * n)
        ]
        for a in axes:
            a.set_xlim(-r, r)
            a.set_ylim(-r, r)