def check_rotation_parity(batch: int = 10, n_atoms: int = 25):
    # Set up the network.
    K = partial(Kernel, RadialModel=ConstantRadialModel)
    C = partial(Convolution, K)
    Rs_in = [(1, 0, +1)]
    f = GatedBlockParity(Operation=C, Rs_in=Rs_in,
                         Rs_scalars=[(4, 0, +1)], act_scalars=[(-1, relu)],
                         Rs_gates=[(8, 0, +1)], act_gates=[(-1, tanh)],
                         Rs_nonscalars=[(4, 1, -1), (4, 2, +1)])
    Rs_out = f.Rs_out

    # Set up the data. The geometry, input features, and output features must all rotate and observe parity.
    abc = torch.randn(3)  # Rotation seed of Euler angles.
    rot_geo = -rot(*abc)  # Negative because geometry has odd parity, i.e. an improper rotation.
    D_in = rep(Rs_in, *abc, parity=1)
    D_out = rep(Rs_out, *abc, parity=1)

    c = sum(mul * (2 * l + 1) for mul, l, _ in Rs_in)
    feat = torch.randn(batch, n_atoms, c)  # Transforms with Wigner D matrix and parity.
    geo = torch.randn(batch, n_atoms, 3)  # Transforms with rotation matrix and parity.

    # Test equivariance.
    F = f(feat, geo)
    RF = torch.einsum("ij,zkj->zki", D_out, F)
    FR = f(feat @ D_in.t(), geo @ rot_geo.t())
    return (RF - FR).norm() < 1e-4 * RF.norm()
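
# Hedged usage sketch (an addition, not part of the original file):
# check_rotation_parity returns a bool tensor, so it can be asserted directly.
# Running under the torch_default_dtype(torch.float64) helper used in the
# tests below keeps numerical error far below the 1e-4 relative tolerance.
def _run_check_rotation_parity():
    with torch_default_dtype(torch.float64):
        assert check_rotation_parity(batch=2, n_atoms=5)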
def test3(self):
    """Test rotation equivariance on GatedBlock and dependencies."""
    with torch_default_dtype(torch.float64):
        Rs_in = [(2, 0), (0, 1), (2, 2)]
        Rs_out = [(2, 0), (2, 1), (2, 2)]

        K = partial(Kernel, RadialModel=ConstantRadialModel)
        C = partial(Convolution, K)
        f = GatedBlock(partial(C, Rs_in), Rs_out, scalar_activation=sigmoid, gate_activation=sigmoid)

        abc = torch.randn(3)
        rot_geo = rot(*abc)
        D_in = direct_sum(*[irr_repr(l, *abc) for mul, l in Rs_in for _ in range(mul)])
        D_out = direct_sum(*[irr_repr(l, *abc) for mul, l in Rs_out for _ in range(mul)])

        fea = torch.randn(1, 4, sum(mul * (2 * l + 1) for mul, l in Rs_in))
        geo = torch.randn(1, 4, 3)

        x1 = torch.einsum("ij,zaj->zai", (D_out, f(fea, geo)))
        x2 = f(torch.einsum("ij,zaj->zai", (D_in, fea)), torch.einsum("ij,zaj->zai", (rot_geo, geo)))
        self.assertLess((x1 - x2).norm(), 1e-4 * x1.norm())
def check_rotation(batch: int = 10, n_atoms: int = 25):
    # Set up the network.
    K = partial(Kernel, RadialModel=ConstantRadialModel)
    C = partial(Convolution, K)
    Rs_in = [(1, 0), (1, 1)]
    Rs_out = [(1, 0), (1, 1), (1, 2)]
    f = GatedBlock(
        partial(C, Rs_in),
        Rs_out,
        scalar_activation=sigmoid,
        gate_activation=absolute,
    )

    # Set up the data. The geometry, input features, and output features must all rotate.
    abc = torch.randn(3)  # Rotation seed of Euler angles.
    rot_geo = rot(*abc)
    D_in = rep(Rs_in, *abc)
    D_out = rep(Rs_out, *abc)

    c = sum(mul * (2 * l + 1) for mul, l in Rs_in)
    feat = torch.randn(batch, n_atoms, c)  # Transforms with Wigner D matrix.
    geo = torch.randn(batch, n_atoms, 3)  # Transforms with rotation matrix.

    # Test equivariance.
    F = f(feat, geo)
    RF = torch.einsum("ij,zkj->zki", D_out, F)
    FR = f(feat @ D_in.t(), geo @ rot_geo.t())
    return (RF - FR).norm() < 1e-4 * RF.norm()
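
# Side-note sketch (an addition, not from the original tests): the two
# transformation idioms used above are the same map. Applying D to the last
# axis via torch.einsum("ij,zkj->zki", D, F) equals the matrix form F @ D.t(),
# which is why check_rotation can rotate the output with einsum and the
# input with @.
def _check_einsum_matmul_equivalence():
    D = torch.randn(7, 7)
    F = torch.randn(10, 25, 7)
    assert torch.allclose(torch.einsum("ij,zkj->zki", D, F), F @ D.t())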
def rotation_from_orientations(t1, f1, t2, f2):
    from e3nn.SO3 import xyz_to_angles, rot

    zero = t1.new_tensor(0)
    r_e_t1 = rot(*xyz_to_angles(t1), zero)  # Rotation taking the reference axis onto t1.
    r_e_t2 = rot(*xyz_to_angles(t2), zero)  # Rotation taking the reference axis onto t2.

    f1_e = r_e_t1.t() @ f1
    f2_e = r_e_t2.t() @ f2
    c = torch.atan2(f2_e[1], f2_e[0]) - torch.atan2(f1_e[1], f1_e[0])
    r_f1_f2 = rot(zero, zero, c)  # In-plane rotation aligning f1 with f2.

    r = r_e_t2 @ r_f1_f2 @ r_e_t1.t()

    # t2 = r @ t1
    # f2 ~ r @ f1
    return r
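
# Hedged usage sketch for rotation_from_orientations (an addition): build a
# (top, front) pair, move it with a known rotation, and check the contract
# stated in the comments above, t2 = r @ t1. The front vector is projected
# to be orthogonal to top so the in-plane angle is well defined.
def _check_rotation_from_orientations():
    from e3nn.SO3 import rot
    r_true = rot(*torch.randn(3))
    t1 = torch.randn(3)
    t1 = t1 / t1.norm()
    f1 = torch.randn(3)
    f1 = f1 - (f1 @ t1) * t1  # Make f1 orthogonal to t1.
    f1 = f1 / f1.norm()
    t2, f2 = r_true @ t1, r_true @ f1
    r = rotation_from_orientations(t1, f1, t2, f2)
    assert (r @ t1 - t2).norm() < 1e-5
    assert (r @ f1 - f2).norm() < 1e-5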
def rotate(x, alpha, beta, gamma):
    t = time_logging.start()
    y = x.cpu().detach().numpy()
    R = rot(alpha, beta, gamma)
    if x.ndimension() == 4:
        # Batch of scalar fields: rotate each channel's grid.
        for i in range(y.shape[0]):
            y[i] = rotate_scalar(y[i], R)
    else:
        y = rotate_scalar(y, R)
    x = x.new_tensor(y)
    if x.ndimension() == 4 and x.size(0) == 3:
        # Three channels form a vector field: also mix the channels with the l = 1 representation.
        rep = irr_repr(1, alpha, beta, gamma, x.dtype).to(x.device)
        x = torch.einsum("ij,jxyz->ixyz", (rep, x))
    time_logging.end("rotate", t)
    return x
def test2(self):
    """Test rotation equivariance on Kernel."""
    with torch_default_dtype(torch.float64):
        Rs_in = [(2, 0), (0, 1), (2, 2)]
        Rs_out = [(2, 0), (2, 1), (2, 2)]

        k = Kernel(Rs_in, Rs_out, ConstantRadialModel)
        r = torch.randn(3)

        abc = torch.randn(3)
        D_in = direct_sum(*[irr_repr(l, *abc) for mul, l in Rs_in for _ in range(mul)])
        D_out = direct_sum(*[irr_repr(l, *abc) for mul, l in Rs_out for _ in range(mul)])

        W1 = D_out @ k(r)  # [i, j]
        W2 = k(rot(*abc) @ r) @ D_in  # [i, j]
        self.assertLess((W1 - W2).norm(), 1e-4 * W1.norm())
def _get_random_affine_trafo(self):
    if self.rotate:
        alpha, beta, gamma = np.pi * np.array([2, 1, 2]) * np.random.rand(3)
        aff = rot(alpha, beta, gamma)
    else:
        aff = np.eye(3)  # Only the non-homogeneous part of the transform.
    fl = (-1) ** np.random.randint(low=0, high=2) if self.flip else 1
    if self.scale is not None:
        sx, sy, sz = np.random.uniform(low=self.scale[0], high=self.scale[1], size=3)
    else:
        sx = sy = sz = 1
    aff[:, 0] *= sx * fl
    aff[:, 1] *= sy
    aff[:, 2] *= sz
    center = self.vol_shape / 2
    offset = center - aff @ center  # Correct the offset to apply the transform around the center.
    if self.translate:
        offset += np.random.uniform(low=-.5, high=.5, size=3)
    return partial(affine_transform, matrix=aff, offset=offset)
def test1(self):
    with torch_default_dtype(torch.float64):
        Rs_in = [(3, 0), (3, 1), (2, 0), (1, 2)]
        Rs_out = [(3, 0), (3, 1), (1, 2), (3, 0)]

        K = partial(Kernel, RadialModel=ConstantRadialModel)
        C = partial(Convolution, K)
        f = GatedBlock(partial(C, Rs_in), Rs_out, rescaled_act.Softplus(beta=5), rescaled_act.sigmoid)

        abc = torch.randn(3)
        D_in = direct_sum(*[irr_repr(l, *abc) for mul, l in Rs_in for _ in range(mul)])
        D_out = direct_sum(*[irr_repr(l, *abc) for mul, l in Rs_out for _ in range(mul)])

        x = torch.randn(1, 5, sum(mul * (2 * l + 1) for mul, l in Rs_in))
        geo = torch.randn(1, 5, 3)

        rx = torch.einsum("ij,zaj->zai", (D_in, x))
        rgeo = geo @ rot(*abc).t()

        y = f(x, geo, dim=2)
        ry = torch.einsum("ij,zaj->zai", (D_out, y))

        self.assertLess((f(rx, rgeo) - ry).norm(), 1e-10 * ry.norm())
def test1(self):
    """Test irr_repr and clebsch_gordan equivariance."""
    with torch_default_dtype(torch.float64):
        l_in = 3
        l_out = 2

        for l_f in range(abs(l_in - l_out), l_in + l_out + 1):
            r = torch.randn(100, 3)
            Q = clebsch_gordan(l_out, l_in, l_f)

            abc = torch.randn(3)
            D_in = irr_repr(l_in, *abc)
            D_out = irr_repr(l_out, *abc)

            Y = spherical_harmonics_xyz(l_f, r @ rot(*abc).t())
            W = torch.einsum("ijk,kz->zij", (Q, Y))
            W1 = torch.einsum("zij,jk->zik", (W, D_in))

            Y = spherical_harmonics_xyz(l_f, r)
            W = torch.einsum("ijk,kz->zij", (Q, Y))
            W2 = torch.einsum("ij,zjk->zik", (D_out, W))

            self.assertLess((W1 - W2).norm(), 1e-5 * W.norm(), l_f)
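
# Related sanity sketch (an addition, reusing the helpers imported by the
# surrounding tests): in the real basis used here the Wigner matrices
# returned by irr_repr are orthogonal, so D.t() @ D is the identity and
# D.t() inverts D.
def _check_irr_repr_orthogonal(l=2):
    with torch_default_dtype(torch.float64):
        D = irr_repr(l, *torch.randn(3))
        assert (D.t() @ D - torch.eye(2 * l + 1)).norm() < 1e-10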
def test6(self):
    """Test parity and rotation equivariance on GatedBlockParity and dependencies."""
    with torch_default_dtype(torch.float64):
        mul = 2
        Rs_in = [(mul, l, p) for l in range(6) for p in [-1, 1]]

        K = partial(Kernel, RadialModel=ConstantRadialModel)
        C = partial(Convolution, K)

        scalars = [(mul, 0, +1), (mul, 0, -1)], [(mul, relu), (mul, absolute)]
        rs_nonscalars = [(mul, 1, +1), (mul, 1, -1), (mul, 2, +1), (mul, 2, -1), (mul, 3, +1), (mul, 3, -1)]
        n = 3 * mul
        gates = [(n, 0, +1), (n, 0, -1)], [(n, sigmoid), (n, tanh)]

        f = GatedBlockParity(C, Rs_in, *scalars, *gates, rs_nonscalars)

        abc = torch.randn(3)
        rot_geo = -rot(*abc)
        D_in = direct_sum(*[p * irr_repr(l, *abc) for mul, l, p in Rs_in for _ in range(mul)])
        D_out = direct_sum(*[p * irr_repr(l, *abc) for mul, l, p in f.Rs_out for _ in range(mul)])

        fea = torch.randn(1, 4, sum(mul * (2 * l + 1) for mul, l, p in Rs_in))
        geo = torch.randn(1, 4, 3)

        x1 = torch.einsum("ij,zaj->zai", (D_out, f(fea, geo)))
        x2 = f(torch.einsum("ij,zaj->zai", (D_in, fea)), torch.einsum("ij,zaj->zai", (rot_geo, geo)))
        self.assertLess((x1 - x2).norm(), 1e-4 * x1.norm())
def get_volumes(size=20, pad=8, rotate=False, rotate90=False):
    assert size >= 4
    tetris_tensorfields = [
        [(0, 0, 0), (0, 0, 1), (1, 0, 0), (1, 1, 0)],  # chiral_shape_1
        [(0, 1, 0), (0, 1, 1), (1, 1, 0), (1, 0, 0)],  # chiral_shape_2
        [(0, 0, 0), (1, 0, 0), (0, 1, 0), (1, 1, 0)],  # square
        [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 0, 3)],  # line
        [(0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0)],  # corner
        [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 0)],  # L
        [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 1)],  # T
        [(0, 0, 0), (1, 0, 0), (1, 1, 0), (2, 1, 0)],  # zigzag
    ]
    labels = np.arange(len(tetris_tensorfields))

    tetris_vox = []
    for shape in tetris_tensorfields:
        volume = np.zeros((4, 4, 4))
        for xi_coords in shape:
            volume[xi_coords] = 1
        volume = zoom(volume, size / 4, order=0)
        volume = np.pad(volume, pad, 'constant')
        if rotate:
            a, c = np.random.rand(2) * 2 * np.pi
            b = np.arccos(np.random.rand() * 2 - 1)  # Uniform over the sphere.
            volume = rotate_scalar(volume, rot(a, b, c))
        if rotate90:
            volume = rot_volume_90(volume)
        tetris_vox.append(volume[np.newaxis, ...])

    tetris_vox = np.stack(tetris_vox).astype(np.float32)
    return tetris_vox, labels
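
# Hedged usage sketch for get_volumes (an addition): each 4x4x4 template is
# zoomed to side `size` and padded on both sides, giving single-channel cubes
# of side size + 2 * pad.
def _check_get_volumes_shapes():
    vox, labels = get_volumes(size=8, pad=2)
    assert vox.shape == (8, 1, 12, 12, 12)
    assert labels.tolist() == list(range(8))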
def train_step(item, top, front, model, optimizer):
    from e3nn.SO3 import rot

    model.train()

    # Jitter the target orientations by a random rotation of at most 15 degrees.
    abc = 15 * (torch.rand(3) * 2 - 1) / 180 * math.pi
    r = rot(*abc).to(item.device)
    tip = torch.cat([torch.einsum("ij,...j->...i", (r, x)) for x in [top, front]], dim=1)
    tip = tip.view(tip.size(0), tip.size(1), 1, 1, 1).expand(tip.size(0), tip.size(1), item.size(2), item.size(3), item.size(4))

    input = torch.cat([item, tip], dim=1)
    prediction = model(input)
    pred_top, pred_front = prediction[:, :3], prediction[:, 3:]

    overlap_top = overlap(pred_top, top)
    overlap_front = overlap(pred_front, front)
    overlap_orth = overlap(pred_top, pred_front)

    optimizer.zero_grad()
    # Maximize overlap with the targets while keeping the two predicted axes orthogonal.
    (-overlap_top - overlap_front + overlap_orth.abs()).mean().backward()
    optimizer.step()

    a = torch.mean(torch.tensor([
        angle_from_rotation(rotation_from_orientations(top[i], front[i], pred_top[i], pred_front[i]))
        for i in range(len(item))
    ]))

    return overlap_top.mean().item(), overlap_front.mean().item(), overlap_orth.abs().mean().item(), a.item()
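
# The `overlap` helper is not defined in this snippet. A plausible stand-in
# (hypothetical, for illustration only) is the cosine similarity between the
# predicted and target direction vectors, which matches how train_step uses
# it: values near 1 mean aligned axes, and overlap(pred_top, pred_front)
# near 0 means the two predicted axes are orthogonal.
def overlap(a, b):
    # a, b: [batch, 3] direction vectors (hypothetical signature).
    return torch.nn.functional.cosine_similarity(a, b, dim=1)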