def test_equivariance(self, tol=1e-4):
    """Check that the model's output is invariant under random rotations.

    Runs the model twice on the same minibatch to measure run-to-run noise,
    then rotates the input positions by random SO(3) elements and checks the
    relative error of the output stays below ``tol``.

    Args:
        tol: maximum allowed mean relative error between original and
            rotated-input outputs.
    """
    self.model.eval()
    mb = self.mb
    # .detach() replaces the deprecated/unsafe .data access
    outs = self.model(mb).detach().cpu().numpy()
    outs2 = self.model(mb).detach().cpu().numpy()
    bs = mb['positions'].shape[0]
    # Draw random unit quaternions, one per batch element.
    q = torch.randn(bs, 1, 4, device=mb['positions'].device,
                    dtype=mb['positions'].dtype)
    q /= norm(q, dim=-1).unsqueeze(-1)
    # Quaternion (cos(t/2), sin(t/2)*u) -> Lie-algebra element, exponentiated
    # to a rotation matrix. NOTE(review): (t/2)*q[...,1:] is not the exact
    # axis-angle t*u, but any valid rotation suffices for this check.
    theta_2 = torch.atan2(norm(q[..., 1:], dim=-1), q[..., 0]).unsqueeze(-1)
    so3_elem = theta_2 * q[..., 1:]
    Rs = SO3.exp(so3_elem)
    # NOTE(review): this mutates self.mb in place — any later test reusing
    # self.mb sees rotated positions; confirm that is intended.
    mb['positions'] = (Rs @ mb['positions'].unsqueeze(-1)).squeeze(-1)
    outs3 = self.model(mb).detach().cpu().numpy()
    diff = np.abs(outs2 - outs).mean() / np.abs(outs).mean()
    print('run through twice rel err:', diff)
    diff = np.abs(outs2 - outs3).mean() / np.abs(outs2).mean()
    print('rotation equivariance rel err:', diff)
    self.assertTrue(diff < tol)
def load(config, **unused_kwargs):
    """Build a MolecLieResNet molecule predictor from ``config``.

    Args:
        config: namespace holding the group choice, seed, and all
            architecture hyperparameters read below.
        **unused_kwargs: accepted and ignored for loader-interface parity.

    Returns:
        tuple: ``(MoleculePredictor, run_name)``.

    Raises:
        ValueError: if ``config.group`` names an unsupported group.
    """
    if config.group == "SE3":
        group = SE3(0.2)
    elif config.group == "SO3":
        group = SO3(0.2)
    elif config.group == "T3":
        group = T(3)
    elif config.group == "Trivial3":
        group = Trivial(3)
    else:
        raise ValueError(f"{config.group} is an invalid group")
    torch.manual_seed(config.model_seed)  # TODO: temp fix of seed
    predictor = MolecLieResNet(
        config.num_species,
        config.charge_scale,
        group=group,
        aug=config.data_augmentation,
        k=config.channels,
        nbhd=config.nbhd_size,
        act=config.activation_function,
        bn=config.batch_norm,
        mean=config.mean_pooling,
        num_layers=config.num_layers,
        fill=config.fill,
        liftsamples=config.lift_samples,
        lie_algebra_nonlinearity=config.lie_algebra_nonlinearity,
    )
    molecule_predictor = MoleculePredictor(predictor, config.task, config.ds_stats)
    return molecule_predictor, f"MoleculeLieResNet_{config.group}"
def forward(self, x):
    """Apply one random SO(3) rotation per batch element to the coordinates.

    Args:
        x: tuple ``(coords, vals, mask)`` with ``coords`` of shape (bs, n, c).

    Returns:
        The same tuple with rotated coordinates, or ``x`` unchanged at eval
        time when ``self.train_only`` is set.
    """
    skip_at_eval = self.train_only and not self.training
    if skip_at_eval:
        return x
    coords, vals, mask = x
    # Sample bs rotation matrices on the same device/dtype as the coords.
    rotations = SO3().sample(
        coords.shape[0], 1, device=coords.device, dtype=coords.dtype
    )
    rotated = (rotations @ coords.unsqueeze(-1)).squeeze(-1)
    return rotated, vals, mask
def load(config, **unused_kwargs):
    """Build a MoleculeEquivariantTransformer predictor from ``config``.

    Args:
        config: namespace holding the group choice, seed, and all
            architecture hyperparameters read below.
        **unused_kwargs: accepted and ignored for loader-interface parity.

    Returns:
        tuple: ``(MoleculePredictor, run_name)``.

    Raises:
        ValueError: if ``config.group`` names an unsupported group.
    """
    if config.group == "SE3":
        group = SE3(0.2)
    elif config.group == "SO3":
        group = SO3(0.2)
    elif config.group == "T3":
        group = T(3)
    elif config.group == "Trivial3":
        group = Trivial(3)
    else:
        raise ValueError(f"{config.group} is an invalid group")
    torch.manual_seed(config.model_seed)
    predictor = MoleculeEquivariantTransformer(
        config.num_species,
        config.charge_scale,
        architecture=config.architecture,
        group=group,
        aug=config.data_augmentation,
        dim_hidden=config.dim_hidden,
        num_layers=config.num_layers,
        num_heads=config.num_heads,
        global_pool=True,
        global_pool_mean=config.mean_pooling,
        liftsamples=config.lift_samples,
        block_norm=config.block_norm,
        output_norm=config.output_norm,
        kernel_norm=config.kernel_norm,
        kernel_type=config.kernel_type,
        kernel_dim=config.kernel_dim,
        kernel_act=config.activation_function,
        fill=config.fill,
        mc_samples=config.mc_samples,
        attention_fn=config.attention_fn,
        feature_embed_dim=config.feature_embed_dim,
        max_sample_norm=config.max_sample_norm,
        lie_algebra_nonlinearity=config.lie_algebra_nonlinearity,
    )
    molecule_predictor = MoleculePredictor(predictor, config.task, config.ds_stats)
    return (
        molecule_predictor,
        f"MoleculeEquivariantTransformer_{config.group}_{config.architecture}",
    )
def load(config, **unused_kwargs):
    """Build a constellation classifier from ``config``.

    Args:
        config: namespace holding the group/featurization choices, seed, and
            all architecture hyperparameters read below.
        **unused_kwargs: accepted and ignored for loader-interface parity.

    Returns:
        tuple: ``(Classifier, run_name)``.

    Raises:
        ValueError: if ``config.group`` names an unsupported group.
        NotImplementedError: if ``config.content_type`` is unknown.
    """
    # BUG FIX: the "SE3" branch was a separate `if`, so a valid "SE3" config
    # fell through to the `else` of the following `if/elif` chain and raised
    # ValueError. One contiguous elif chain handles every group exactly once.
    if config.group == "SE3":
        group = SE3(0.2)
    elif config.group == "SE2":
        group = SE2(0.2)
    elif config.group == "SO3":
        group = SO3(0.2)
    elif config.group == "T3":
        group = T(3)
    elif config.group == "T2":
        group = T(2)
    elif config.group == "Trivial3":
        group = Trivial(3)
    else:
        raise ValueError(f"{config.group} is an invalid group")
    if config.content_type == "centroidal":
        dim_input = 2
        feature_function = constant_features
    elif config.content_type == "constant":
        dim_input = 1
        # Every point gets a single constant feature channel of ones.
        feature_function = lambda X, presence: torch.ones(
            X.shape[:-1], dtype=X.dtype, device=X.device
        ).unsqueeze(-1)
    elif config.content_type == "pairwise_distances":
        dim_input = config.patterns_reps * 17 - 1
        feature_function = pairwise_distance_features
    elif config.content_type == "distance_moments":
        dim_input = config.distance_moments
        feature_function = lambda X, presence: pairwise_distance_moment_features(
            X, presence, n_moments=config.distance_moments
        )
    else:
        raise NotImplementedError(
            f"{config.content_type} featurization not implemented"
        )
    torch.manual_seed(config.model_seed)
    predictor = ConstellationEquivariantTransformer(
        n_patterns=4,
        patterns_reps=config.patterns_reps,
        feature_function=feature_function,
        group=group,
        dim_input=dim_input,
        dim_hidden=config.dim_hidden,
        num_layers=config.num_layers,
        num_heads=config.num_heads,
        global_pool=True,
        global_pool_mean=config.mean_pooling,
        liftsamples=config.lift_samples,
        kernel_dim=config.kernel_dim,
        kernel_act=config.activation_function,
        block_norm=config.block_norm,
        output_norm=config.output_norm,
        kernel_norm=config.kernel_norm,
        kernel_type=config.kernel_type,
        architecture=config.architecture,
        attention_fn=config.attention_fn,
    )
    classifier = Classifier(predictor)
    return classifier, f"ConstellationEquivariantTransformer_{config.group}"
import numpy as np import torch import scipy as sp import scipy.linalg import unittest from lie_conv.lieGroups import SO3,SE3,SE2,SO2 test_groups = [SO2(),SO3(),SE3()] class TestGroups(unittest.TestCase): def test_exp_correct(self,num_trials=3,tol=1e-4): for group in test_groups: for i in np.linspace(-5,2,10): for _ in range(num_trials): w = torch.rand(group.embed_dim)*(10**i) R = group.exp(w).data.numpy() A = group.components2matrix(w).data.numpy() R2 = sp.linalg.expm(A) err = np.abs(R2-R).mean() if err>tol: print(f'{group} exp check failed with {err:.2E} at |w|={w.abs().mean():.2E}') self.assertTrue(err<tol) def test_log_correct(self,num_trials=3,tol=1e-4): for group in test_groups: for i in np.linspace(-2,2,10): for _ in range(num_trials): w = (torch.rand(group.embed_dim)*(10**i)) A = group.components2matrix(w).data.numpy() R = sp.linalg.expm(A) lR = sp.linalg.logm(R,disp=False) logR = group.matrix2components(torch.from_numpy(lR[0].real.astype(np.float32))) logR2 = group.log(torch.from_numpy(R.astype(np.float32))) err = (((logR2-logR).abs()).mean()/logR.abs().mean()).data