# Example 1
 def test_equivariance(self, tol=1e-4):
     """Check the model is (numerically) invariant under random rotations.

     Runs the model twice on the same minibatch to gauge run-to-run noise,
     then once more on a randomly rotated copy of the positions, and asserts
     the relative output difference stays below ``tol``.
     """
     self.model.eval()
     mb = self.mb
     # Two passes on identical input measure nondeterminism of the forward pass.
     baseline = self.model(mb).cpu().data.numpy()
     repeat = self.model(mb).cpu().data.numpy()
     positions = mb['positions']
     batch_size = positions.shape[0]
     # Draw one random unit quaternion per batch element.
     quats = torch.randn(batch_size,
                         1,
                         4,
                         device=positions.device,
                         dtype=positions.dtype)
     quats = quats / norm(quats, dim=-1).unsqueeze(-1)
     # Quaternion (w, xyz) -> so(3) element: rotation axis scaled by half-angle.
     half_angle = torch.atan2(norm(quats[..., 1:], dim=-1),
                              quats[..., 0]).unsqueeze(-1)
     rotations = SO3.exp(half_angle * quats[..., 1:])
     mb['positions'] = (rotations @ positions.unsqueeze(-1)).squeeze(-1)
     rotated = self.model(mb).cpu().data.numpy()
     rel_err = np.abs(repeat - baseline).mean() / np.abs(baseline).mean()
     print('run through twice rel err:', rel_err)
     rel_err = np.abs(repeat - rotated).mean() / np.abs(repeat).mean()
     print('rotation equivariance rel err:', rel_err)
     self.assertTrue(rel_err < tol)
# Example 2
def load(config, **unused_kwargs):
    """Build a LieResNet molecule predictor from ``config``.

    Args:
        config: hyperparameter namespace; ``config.group`` selects the
            symmetry group ("SE3", "SO3", "T3" or "Trivial3"), the remaining
            attributes parameterize the network.
        **unused_kwargs: ignored; accepted for loader-interface compatibility.

    Returns:
        Tuple of ``(MoleculePredictor wrapping the model, model name string)``.

    Raises:
        ValueError: if ``config.group`` names an unsupported group.
    """
    if config.group == "SE3":
        group = SE3(0.2)
    elif config.group == "SO3":
        group = SO3(0.2)
    elif config.group == "T3":
        group = T(3)
    elif config.group == "Trivial3":
        group = Trivial(3)
    else:
        # Fixed error-message typo: "is and invalid" -> "is an invalid".
        raise ValueError(f"{config.group} is an invalid group")

    torch.manual_seed(config.model_seed)  # TODO: temp fix of seed
    predictor = MolecLieResNet(
        config.num_species,
        config.charge_scale,
        group=group,
        aug=config.data_augmentation,
        k=config.channels,
        nbhd=config.nbhd_size,
        act=config.activation_function,
        bn=config.batch_norm,
        mean=config.mean_pooling,
        num_layers=config.num_layers,
        fill=config.fill,
        liftsamples=config.lift_samples,
        lie_algebra_nonlinearity=config.lie_algebra_nonlinearity,
    )

    molecule_predictor = MoleculePredictor(predictor, config.task,
                                           config.ds_stats)

    return molecule_predictor, f"MoleculeLieResNet_{config.group}"
# Example 3
 def forward(self, x):
     """Rotate each example's coordinates by a random SO(3) element.

     NOTE(review): assumes ``x`` is a ``(coords, vals, mask)`` triple with
     ``coords`` of shape (bs, n, c) -- confirm against callers.
     """
     # Train-only augmentation: pass through unchanged at eval time.
     if not self.training and self.train_only:
         return x
     coords, vals, mask = x
     batch_size = coords.shape[0]
     rotations = SO3().sample(batch_size,
                              1,
                              device=coords.device,
                              dtype=coords.dtype)
     rotated = (rotations @ coords.unsqueeze(-1)).squeeze(-1)
     return (rotated, vals, mask)
# Example 4
def load(config, **unused_kwargs):
    """Build an equivariant-transformer molecule predictor from ``config``.

    Args:
        config: hyperparameter namespace; ``config.group`` selects the
            symmetry group ("SE3", "SO3", "T3" or "Trivial3"), the remaining
            attributes parameterize the transformer.
        **unused_kwargs: ignored; accepted for loader-interface compatibility.

    Returns:
        Tuple of ``(MoleculePredictor wrapping the model, model name string)``.

    Raises:
        ValueError: if ``config.group`` names an unsupported group.
    """
    if config.group == "SE3":
        group = SE3(0.2)
    elif config.group == "SO3":
        group = SO3(0.2)
    elif config.group == "T3":
        group = T(3)
    elif config.group == "Trivial3":
        group = Trivial(3)
    else:
        # Fixed error-message typo: "is and invalid" -> "is an invalid".
        raise ValueError(f"{config.group} is an invalid group")

    torch.manual_seed(config.model_seed)
    predictor = MoleculeEquivariantTransformer(
        config.num_species,
        config.charge_scale,
        architecture=config.architecture,
        group=group,
        aug=config.data_augmentation,
        dim_hidden=config.dim_hidden,
        num_layers=config.num_layers,
        num_heads=config.num_heads,
        global_pool=True,
        global_pool_mean=config.mean_pooling,
        liftsamples=config.lift_samples,
        block_norm=config.block_norm,
        output_norm=config.output_norm,
        kernel_norm=config.kernel_norm,
        kernel_type=config.kernel_type,
        kernel_dim=config.kernel_dim,
        kernel_act=config.activation_function,
        fill=config.fill,
        mc_samples=config.mc_samples,
        attention_fn=config.attention_fn,
        feature_embed_dim=config.feature_embed_dim,
        max_sample_norm=config.max_sample_norm,
        lie_algebra_nonlinearity=config.lie_algebra_nonlinearity,
    )

    molecule_predictor = MoleculePredictor(predictor, config.task,
                                           config.ds_stats)

    return (
        molecule_predictor,
        f"MoleculeEquivariantTransformer_{config.group}_{config.architecture}",
    )
# Example 5
def load(config, **unused_kwargs):
    """Build a constellation equivariant-transformer classifier from ``config``.

    Args:
        config: hyperparameter namespace; ``config.group`` selects the
            symmetry group, ``config.content_type`` selects the featurization,
            the remaining attributes parameterize the transformer.
        **unused_kwargs: ignored; accepted for loader-interface compatibility.

    Returns:
        Tuple of ``(Classifier wrapping the model, model name string)``.

    Raises:
        ValueError: if ``config.group`` names an unsupported group.
        NotImplementedError: if ``config.content_type`` is unknown.
    """
    if config.group == "SE3":
        group = SE3(0.2)
    # BUGFIX: this branch was a bare `if`, which broke the chain: for
    # group "SE3" the second chain's `else` still raised ValueError.
    elif config.group == "SE2":
        group = SE2(0.2)
    elif config.group == "SO3":
        group = SO3(0.2)
    elif config.group == "T3":
        group = T(3)
    elif config.group == "T2":
        group = T(2)
    elif config.group == "Trivial3":
        group = Trivial(3)
    else:
        # Fixed error-message typo: "is and invalid" -> "is an invalid".
        raise ValueError(f"{config.group} is an invalid group")

    if config.content_type == "centroidal":
        dim_input = 2
        feature_function = constant_features
    elif config.content_type == "constant":
        dim_input = 1
        feature_function = lambda X, presence: torch.ones(X.shape[:-1], dtype=X.dtype, device=X.device).unsqueeze(-1)
    elif config.content_type == "pairwise_distances":
        dim_input = config.patterns_reps * 17 - 1
        feature_function = pairwise_distance_features  # i.e. use the arg dim_input
    elif config.content_type == "distance_moments":
        dim_input = config.distance_moments
        feature_function = lambda X, presence: pairwise_distance_moment_features(
            X, presence, n_moments=config.distance_moments
        )
    else:
        raise NotImplementedError(
            f"{config.content_type} featurization not implemented"
        )

    output_dim = config.patterns_reps + 1

    torch.manual_seed(config.model_seed)
    predictor = ConstellationEquivariantTransformer(
        n_patterns=4,
        patterns_reps=config.patterns_reps,
        feature_function=feature_function,
        group=group,
        dim_input=dim_input,
        dim_hidden=config.dim_hidden,
        num_layers=config.num_layers,
        num_heads=config.num_heads,
        # layer_norm=config.layer_norm,
        global_pool=True,
        global_pool_mean=config.mean_pooling,
        liftsamples=config.lift_samples,
        kernel_dim=config.kernel_dim,
        kernel_act=config.activation_function,
        block_norm=config.block_norm,
        output_norm=config.output_norm,
        kernel_norm=config.kernel_norm,
        kernel_type=config.kernel_type,
        architecture=config.architecture,
        # batch_norm=config.batch_norm,
        # location_attention=config.location_attention,
        attention_fn=config.attention_fn,
    )

    classifier = Classifier(predictor)

    return classifier, f"ConstellationEquivariantTransformer_{config.group}"
# Example 6
import numpy as np
import torch
import scipy as sp
import scipy.linalg
import unittest

from lie_conv.lieGroups import SO3,SE3,SE2,SO2
test_groups = [SO2(),SO3(),SE3()]
class TestGroups(unittest.TestCase):
    def test_exp_correct(self, num_trials=3, tol=1e-4):
        """Compare each group's exp map against scipy's dense matrix exponential."""
        # log10 magnitudes of the sampled Lie-algebra elements.
        scales = np.linspace(-5, 2, 10)
        for group in test_groups:
            for exponent in scales:
                for _ in range(num_trials):
                    w = torch.rand(group.embed_dim) * (10 ** exponent)
                    exp_w = group.exp(w).data.numpy()
                    # Reference: exponentiate the matrix form of w via scipy.
                    matrix_form = group.components2matrix(w).data.numpy()
                    reference = sp.linalg.expm(matrix_form)
                    err = np.abs(reference - exp_w).mean()
                    if err > tol:
                        print(f'{group} exp check failed with {err:.2E} at |w|={w.abs().mean():.2E}')
                    self.assertTrue(err < tol)
    def test_log_correct(self,num_trials=3,tol=1e-4):
        for group in test_groups:
            for i in np.linspace(-2,2,10):
                for _ in range(num_trials):
                    w = (torch.rand(group.embed_dim)*(10**i))
                    A = group.components2matrix(w).data.numpy()
                    R = sp.linalg.expm(A)
                    lR = sp.linalg.logm(R,disp=False)
                    logR = group.matrix2components(torch.from_numpy(lR[0].real.astype(np.float32)))
                    logR2 = group.log(torch.from_numpy(R.astype(np.float32)))
                    err = (((logR2-logR).abs()).mean()/logR.abs().mean()).data