def get_dataset():
    """Build the eight 3D tetris pieces, all rotated by one random rotation.

    :return: ``(tetris, labels)`` where

        * ``tetris`` has shape ``[8, 4, 3]`` and the default dtype: the
          coordinates of the four cubes of each piece, after every piece is
          rotated by the *same* matrix drawn from ``o3.rand_rot()``;
        * ``labels`` has shape ``[8, 7]`` and the default dtype. Column 0
          holds +1 / -1 for the two mirror-image chiral pieces (0 for the
          rest); columns 1-6 one-hot encode the remaining six pieces.
    """
    pieces = {
        "chiral_shape_1": [(0, 0, 0), (0, 0, 1), (1, 0, 0), (1, 1, 0)],
        "chiral_shape_2": [(0, 0, 0), (0, 0, 1), (1, 0, 0), (1, -1, 0)],
        "square": [(0, 0, 0), (1, 0, 0), (0, 1, 0), (1, 1, 0)],
        "line": [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 0, 3)],
        "corner": [(0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0)],
        "L": [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 0)],
        "T": [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 1)],
        "zigzag": [(0, 0, 0), (1, 0, 0), (1, 1, 0), (2, 1, 0)],
    }
    dtype = torch.get_default_dtype()
    # dicts keep insertion order, so rows follow the listing above.
    coords = torch.tensor(list(pieces.values()), dtype=dtype)

    # The chiral pair shares column 0 with opposite signs, so the target
    # flips under reflection; every other piece gets its own column.
    labels = torch.tensor(
        [
            [+1, 0, 0, 0, 0, 0, 0],  # chiral_shape_1
            [-1, 0, 0, 0, 0, 0, 0],  # chiral_shape_2
            [0, 1, 0, 0, 0, 0, 0],  # square
            [0, 0, 1, 0, 0, 0, 0],  # line
            [0, 0, 0, 1, 0, 0, 0],  # corner
            [0, 0, 0, 0, 1, 0, 0],  # L
            [0, 0, 0, 0, 0, 1, 0],  # T
            [0, 0, 0, 0, 0, 0, 1],  # zigzag
        ],
        dtype=dtype,
    )

    # One random rotation applied to every cube of every piece.
    rotation = o3.rand_rot()
    rotated = torch.einsum('ij,zaj->zai', rotation, coords)
    return rotated, labels
def test_rot_to_abc(self):
    """``o3.rot`` applied to the angles from ``o3.rot_to_abc`` recovers the matrix.

    Runs in float64 so the 1e-10 relative tolerance is meaningful.
    """
    with o3.torch_default_dtype(torch.float64):
        original = o3.rand_rot()
        angles = o3.rot_to_abc(original)
        rebuilt = o3.rot(*angles)
        # Relative Frobenius-norm error between the two matrices.
        rel_err = torch.norm(original - rebuilt) / torch.norm(original)
        self.assertTrue(rel_err < 1e-10, rel_err)
def __init__(self, Rs, act, n):
    """Pointwise nonlinearity applied to features viewed as a signal on SO(3).

    The features are treated as coefficients of a signal on SO(3), i.e. the
    regular representation, on which a pointwise operation is well defined.
    The signal is sampled at ``n`` random rotations, ``act`` is applied
    pointwise, and the result is projected back. This constructor validates
    ``Rs`` and precomputes the sampling matrix ``Z``. The sample / activate /
    project steps are presumably carried out in ``forward`` (not visible
    here).

    :param Rs: input representation. After ``rs.simplify`` its entries are
        ``(mul, l, p)`` triples, which must list ``l = 0, 1, ..., L`` exactly
        once each and in order, all with parity ``p == 0``, and with
        ``mul == mul0 * (2 * l + 1)`` where ``mul0`` is the ``l = 0``
        multiplicity.
    :param act: activation function, stored as ``self.act``.
    :param n: number of random rotations (points on SO(3)) used to sample the
        signal; more samples give a more accurate approximation.
    :raises AssertionError: if ``Rs`` does not have the structure above.
    """
    super().__init__()

    Rs = rs.simplify(Rs)
    mul0, _, _ = Rs[0]
    # Each degree l must carry mul0 * (2l+1) copies. Presumably this lets its
    # mul * (2l+1) features line up with mul0 copies of the (2l+1)**2 entries
    # that degree l contributes to each row of Z.
    assert all(mul0 * (2 * l + 1) == mul for mul, l, _ in Rs), (
        "multiplicity of each l must be mul0 * (2l+1), got Rs={}".format(Rs))
    assert [l for _, l, _ in Rs] == list(range(len(Rs))), (
        "Rs must contain l = 0, 1, ..., L exactly once each, got Rs={}".format(Rs))
    assert all(p == 0 for _, _, p in Rs), (
        "all parities in Rs must be 0, got Rs={}".format(Rs))
    self.Rs_out = Rs

    rotations = [o3.rand_rot() for _ in range(n)]
    rows = []
    for R in rotations:
        # The angles depend only on R, so compute them once per sample rather
        # than once per degree l.
        abc = o3.rot_to_abc(R)
        rows.append(torch.cat([
            o3.irr_repr(l, *abc).flatten() * (2 * l + 1) ** 0.5
            for l in range(len(Rs))
        ]))
    # [n, lmn]: one row per sampled rotation; presumably lmn = sum_l (2l+1)**2.
    Z = torch.stack(rows)
    # Global normalization by the square root of the number of columns.
    Z.div_(Z.shape[1] ** 0.5)
    self.register_buffer('Z', Z)
    self.act = act
def get_dataset():
    """Build the eight 3D tetris pieces, each rotated by its own random rotation.

    :return: ``(tetris, labels)`` where

        * ``tetris`` has shape ``[8, 4, 3]`` and the default dtype: the cube
          coordinates of each piece after rotation by a matrix drawn
          independently (one ``o3.rand_rot()`` call per piece, in order);
        * ``labels`` is ``torch.arange(8)``, the integer class index of each
          piece in the order listed below.
    """
    pieces = [
        [(0, 0, 0), (0, 0, 1), (1, 0, 0), (1, 1, 0)],  # chiral_shape_1
        [(0, 0, 0), (0, 0, 1), (1, 0, 0), (1, -1, 0)],  # chiral_shape_2
        [(0, 0, 0), (1, 0, 0), (0, 1, 0), (1, 1, 0)],  # square
        [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 0, 3)],  # line
        [(0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0)],  # corner
        [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 0)],  # L
        [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 1)],  # T
        [(0, 0, 0), (1, 0, 0), (1, 1, 0), (2, 1, 0)],  # zigzag
    ]
    coords = torch.tensor(pieces, dtype=torch.get_default_dtype())
    labels = torch.arange(coords.shape[0])

    rotated = []
    for piece in coords:
        # A fresh rotation for every piece; applied to each of its 4 points.
        rotation = o3.rand_rot()
        rotated.append(torch.einsum("ij,nj->ni", rotation, piece))
    tetris = torch.stack(rotated)
    return tetris, labels