def check_network_equivariance(network, Rs_in, Rs_out, vector_output: bool = False, batch: int = 2, n_atoms: int = 6) -> bool:
    """
    Checks rotation (and, if required, parity) equivariance of a network.
    The network output must transform with the Wigner D matrices built from Rs_out,
    while the input features transform with those built from Rs_in.

    :param network: Often the forward pass of an object inheriting from torch.nn.Module
    :param Rs_in: Rs for the input of the network.
    :param Rs_out: Rs for the output of the network.
    :param vector_output: When True, the network outputs vectors which have already been converted to the xyz basis.
    :param batch: Batch size of the test case.
    :param n_atoms: Number of atoms of the test case.
    :return: True if the network output is equivariant within tolerance.
    """
    def check_parity(Rs):
        """Does Rs require a test of parity equivariance."""
        return any(p != 0 for _, _, p in normalizeRs(Rs))

    # Make sure that all parameters are on the same device. Set that to the calculation device.
    devices = [i.device for i in network.parameters()]
    device = devices[0]
    assert all(device == i for i in devices)

    Rs_in = normalizeRs(Rs_in)
    Rs_out = normalizeRs(Rs_out)
    parity = check_parity(Rs_in) or check_parity(Rs_out)

    abc = torch.randn(3, device=device)  # Rotation seed of Euler angles.
    D_in = rep(Rs_in, *abc)
    geo_rotation_matrix = -rot(*abc) if parity else rot(*abc)  # Geometry is odd parity.
    D_out = rep(Rs_out, *abc)

    c = sum(mul * (2 * l + 1) for mul, l, _ in Rs_in)
    feat = torch.randn(batch, n_atoms, c, device=device)  # Transforms with the Wigner D matrix.
    geo = torch.randn(batch, n_atoms, 3, device=device)  # Transforms with the rotation matrix.

    F = network(feat, geo)
    if vector_output:
        RF = torch.einsum("ij,zkj->zki", geo_rotation_matrix, F)
    else:
        RF = torch.einsum("ij,zkj->zki", D_out, F)
    FR = network(feat @ D_in.t(), geo @ geo_rotation_matrix.t())  # [batch, N, feat]
    return (RF - FR).norm() < 10e-5 * RF.norm()
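# Hedged usage sketch (not part of the original module; the Identity network and the Rs list
# below are illustrative assumptions): a network that returns its scalar (l = 0) features
# unchanged is trivially equivariant, since the Wigner D matrix for l = 0 is the identity,
# so the check should pass.
def _example_check_identity_network():
    class Identity(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # Dummy parameter so network.parameters() is non-empty for the device check.
            self.bias = torch.nn.Parameter(torch.zeros(1))

        def forward(self, features, geometry):
            return features + self.bias

    Rs = [(4, 0, 0)]  # four scalar channels: multiplicity 4, l = 0, even parity
    assert check_network_equivariance(Identity(), Rs, Rs)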
def rotate(x, alpha):
    t = time_logging.start()
    y = x.cpu().detach().numpy()
    R = rot(alpha, 0, 0)
    y = rotate_scalar(y, R)
    x = x.new_tensor(y)
    time_logging.end("rotate", t)
    return x
def rotation_from_orientations(t1, f1, t2, f2):
    from se3cnn.SO3 import x_to_alpha_beta, rot

    zero = t1.new_tensor(0)
    r_e_t1 = rot(*x_to_alpha_beta(t1), zero)
    r_e_t2 = rot(*x_to_alpha_beta(t2), zero)

    f1_e = r_e_t1.t() @ f1
    f2_e = r_e_t2.t() @ f2
    c = torch.atan2(f2_e[1], f2_e[0]) - torch.atan2(f1_e[1], f1_e[0])
    r_f1_f2 = rot(zero, zero, c)

    r = r_e_t2 @ r_f1_f2 @ r_e_t1.t()

    # t2 = r @ t1
    # f2 ~ r @ f1
    return r
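# Hedged usage sketch (the concrete axes below are illustrative assumptions): the returned
# matrix maps the first (top, front) orientation onto the second, i.e. t2 = r @ t1 for unit
# vectors, as stated in the comment above.
def _example_rotation_from_orientations():
    t1 = torch.tensor([0.0, 0.0, 1.0])  # first "top" axis
    f1 = torch.tensor([1.0, 0.0, 0.0])  # first "front" axis
    t2 = torch.tensor([0.0, 1.0, 0.0])  # second "top" axis
    f2 = torch.tensor([1.0, 0.0, 0.0])  # second "front" axis
    r = rotation_from_orientations(t1, f1, t2, f2)
    assert torch.allclose(r @ t1, t2, atol=1e-5)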
def rotate(x, alpha, beta, gamma):
    t = time_logging.start()
    y = x.cpu().detach().numpy()
    R = rot(alpha, beta, gamma)
    if x.ndimension() == 4:
        # Rotate each channel of the field on the voxel grid.
        for i in range(y.shape[0]):
            y[i] = rotate_scalar(y[i], R)
    else:
        y = rotate_scalar(y, R)
    x = x.new_tensor(y)
    if x.ndimension() == 4 and x.size(0) == 3:
        # Three channels are interpreted as an l = 1 field: mix the channels as well.
        rep = irr_repr(1, alpha, beta, gamma, x.dtype).to(x.device)
        x = torch.einsum("ij,jxyz->ixyz", (rep, x))
    time_logging.end("rotate", t)
    return x
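# Hedged usage sketch (the shapes below are illustrative assumptions): a plain 3-D tensor is
# rotated as a scalar field on the grid, while a 3-channel 4-D tensor additionally has its
# channel axis mixed by the l = 1 representation.
def _example_rotate_fields():
    scalar_field = torch.randn(16, 16, 16)
    vector_field = torch.randn(3, 16, 16, 16)
    rotated_scalar = rotate(scalar_field, 0.3, 0.2, 0.1)  # grid rotation only
    rotated_vector = rotate(vector_field, 0.3, 0.2, 0.1)  # grid rotation + channel mixing
    assert rotated_scalar.shape == scalar_field.shape
    assert rotated_vector.shape == vector_field.shape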
def _get_random_affine_trafo(self):
    if self.rotate:
        alpha, beta, gamma = np.pi * np.array([2, 1, 2]) * np.random.rand(3)
        aff = rot(alpha, beta, gamma)
    else:
        aff = np.eye(3)  # only non-homogeneous coord part
    fl = (-1)**np.random.randint(low=0, high=2) if self.flip else 1
    if self.scale is not None:
        sx, sy, sz = np.random.uniform(low=self.scale[0], high=self.scale[1], size=3)
    else:
        sx = sy = sz = 1
    aff[:, 0] *= sx * fl
    aff[:, 1] *= sy
    aff[:, 2] *= sz
    center = self.vol_shape / 2
    offset = center - center @ aff.T  # correct offset to apply trafo around center
    if self.translate:
        offset += np.random.uniform(low=-.5, high=.5, size=3)
    return partial(affine_transform, matrix=aff, offset=offset)
def get_volumes(size=20, pad=8, rotate=False, rotate90=False):
    assert size >= 4
    tetris_tensorfields = [
        [(0, 0, 0), (0, 0, 1), (1, 0, 0), (1, 1, 0)],  # chiral_shape_1
        [(0, 1, 0), (0, 1, 1), (1, 1, 0), (1, 0, 0)],  # chiral_shape_2
        [(0, 0, 0), (1, 0, 0), (0, 1, 0), (1, 1, 0)],  # square
        [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 0, 3)],  # line
        [(0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0)],  # corner
        [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 0)],  # L
        [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 1)],  # T
        [(0, 0, 0), (1, 0, 0), (1, 1, 0), (2, 1, 0)],  # zigzag
    ]
    labels = np.arange(len(tetris_tensorfields))

    tetris_vox = []
    for shape in tetris_tensorfields:
        volume = np.zeros((4, 4, 4))
        for xi_coords in shape:
            volume[xi_coords] = 1
        volume = zoom(volume, size / 4, order=0)
        volume = np.pad(volume, pad, 'constant')
        if rotate:
            a, c = np.random.rand(2) * 2 * np.pi
            b = np.arccos(np.random.rand() * 2 - 1)
            volume = rotate_scalar(volume, rot(a, b, c))
        if rotate90:
            volume = rot_volume_90(volume)
        tetris_vox.append(volume[np.newaxis, ...])

    tetris_vox = np.stack(tetris_vox).astype(np.float32)
    return tetris_vox, labels
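# Hedged usage sketch (illustrative, not from the original code): generate the eight tetris
# volumes with the default settings. With size=20 and pad=8 each volume measures
# 20 + 2 * 8 = 36 voxels per side.
def _example_get_volumes():
    volumes, labels = get_volumes(size=20, pad=8, rotate=True)
    assert volumes.shape == (8, 1, 36, 36, 36)
    assert labels.shape == (8,)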
def train_step(item, top, front, model, optimizer):
    from se3cnn.SO3 import rot

    model.train()

    # Perturb the target orientations by a small random rotation (within +/- 15 degrees)
    # before feeding them to the model as a hint.
    abc = 15 * (torch.rand(3) * 2 - 1) / 180 * math.pi
    r = rot(*abc).to(item.device)
    tip = torch.cat([torch.einsum("ij,...j->...i", (r, x)) for x in [top, front]], dim=1)
    # Broadcast the orientation vectors over the spatial dimensions of the volume.
    tip = tip.view(tip.size(0), tip.size(1), 1, 1, 1).expand(tip.size(0), tip.size(1), item.size(2), item.size(3), item.size(4))

    input = torch.cat([item, tip], dim=1)
    prediction = model(input)
    pred_top, pred_front = prediction[:, :3], prediction[:, 3:]

    overlap_top = overlap(pred_top, top)
    overlap_front = overlap(pred_front, front)
    overlap_orth = overlap(pred_top, pred_front)

    # Maximize the overlap with the targets and penalize non-orthogonal predictions.
    optimizer.zero_grad()
    (-overlap_top - overlap_front + overlap_orth.abs()).mean().backward()
    optimizer.step()

    # Mean angular error between the predicted and target orientation frames.
    a = torch.mean(torch.tensor([
        angle_from_rotation(rotation_from_orientations(top[i], front[i], pred_top[i], pred_front[i]))
        for i in range(len(item))
    ]))

    return overlap_top.mean().item(), overlap_front.mean().item(), overlap_orth.abs().mean().item(), a.item()