Example #1
0
 def __init__(self,
              chin=1,
              total_ds=1 / 64,
              num_layers=6,
              group=T(2),
              fill=1 / 32,
              k=256,
              knn=False,
              nbhd=12,
              num_targets=10,
              increase_channels=True,
              **kwargs):
     """Configure the parent network with an overall downsampling of ``total_ds``.

     The total downsampling ``total_ds`` is split into a per-layer factor
     ``total_ds ** (1 / num_layers)``. The scalar ``fill`` is expanded into one
     value per layer, dividing it by that per-layer factor raised to the layer
     index. When ``increase_channels`` is true, ``k`` is likewise expanded into
     ``num_layers + 1`` channel counts, dividing ``k`` by the per-layer factor
     raised to half the layer index (truncated to ``int``).

     Args:
         chin: number of input channels.
         total_ds: overall downsampling factor across all layers.
         num_layers: number of layers.
         group: group passed through to the parent constructor.
         fill: base fill fraction for the first layer.
         k: base channel count (or passed through unchanged when
             ``increase_channels`` is false).
         knn: passed through to the parent constructor.
         nbhd: neighbourhood size passed through to the parent constructor.
         num_targets: forwarded as ``num_outputs``.
         increase_channels: whether to scale channels as the input is downsampled.
         **kwargs: forwarded to the parent constructor.
     """
     layer_ds = total_ds**(1 / num_layers)

     per_layer_fill = []
     for depth in range(num_layers):
         per_layer_fill.append(fill / layer_ds**depth)

     channels = k
     if increase_channels:
         # One entry per layer boundary, hence num_layers + 1 values.
         channels = [int(k / layer_ds**(depth / 2)) for depth in range(num_layers + 1)]

     super().__init__(chin=chin, ds_frac=layer_ds, num_layers=num_layers,
                      nbhd=nbhd, mean=True, group=group, fill=per_layer_fill,
                      k=channels, num_outputs=num_targets, cache=True, knn=knn,
                      **kwargs)
     self.lifted_coords = None
Example #2
0
def load(config, **unused_kwargs):
    """Build a ``MoleculePredictor`` wrapping a ``MolecLieResNet``.

    Args:
        config: experiment configuration. ``config.group`` selects the group
            ("SE3", "SO3", "T3" or "Trivial3"); the other attributes read below
            are forwarded to the network and predictor.
        **unused_kwargs: accepted for interface compatibility and ignored.

    Returns:
        ``(molecule_predictor, name)`` where ``name`` is
        ``"MoleculeLieResNet_<config.group>"``.

    Raises:
        ValueError: if ``config.group`` is not one of the supported names.
    """
    if config.group == "SE3":
        group = SE3(0.2)
    elif config.group == "SO3":
        group = SO3(0.2)
    elif config.group == "T3":
        group = T(3)
    elif config.group == "Trivial3":
        group = Trivial(3)
    else:
        raise ValueError(f"{config.group} is an invalid group")

    # Seed right before construction so weight initialisation is reproducible.
    torch.manual_seed(config.model_seed)  # TODO: temp fix of seed
    predictor = MolecLieResNet(
        config.num_species,
        config.charge_scale,
        group=group,
        aug=config.data_augmentation,
        k=config.channels,
        nbhd=config.nbhd_size,
        act=config.activation_function,
        bn=config.batch_norm,
        mean=config.mean_pooling,
        num_layers=config.num_layers,
        fill=config.fill,
        liftsamples=config.lift_samples,
        lie_algebra_nonlinearity=config.lie_algebra_nonlinearity,
    )

    molecule_predictor = MoleculePredictor(predictor, config.task,
                                           config.ds_stats)

    return molecule_predictor, f"MoleculeLieResNet_{config.group}"
Example #3
0
def load(config, **unused_kwargs):
    """Build a ``DynamicsPredictor`` around a ``DynamicsEquivariantTransformer``.

    Args:
        config: experiment configuration. ``config.group`` selects the group
            ("T(2)", "T(3)", "SE(2)", "SE(2)_canonical" or "SO(2)") and
            ``config.data_config`` (a data-config file path) selects the task.
        **unused_kwargs: accepted for interface compatibility and ignored.

    Returns:
        ``(dynamics_predictor, "EqvTransformer_Dynamics")``.

    Raises:
        NotImplementedError: if ``config.group`` is not supported.
        ValueError: if ``config.data_config`` is not one of the known paths.
    """
    if config.group == "T(2)":
        group = T(2)
    elif config.group == "T(3)":
        group = T(3)
    elif config.group == "SE(2)":
        group = SE2()
    elif config.group == "SE(2)_canonical":
        group = SE2_canonical()
    elif config.group == "SO(2)":
        group = SO2()
    else:
        raise NotImplementedError(f"Group {config.group} is not implemented.")

    # Resolve the task before building the network: previously an unknown
    # data config left `task` unbound and failed only after construction.
    if config.data_config == "configs/dynamics/nbody_dynamics_data.py":
        task = "nbody"
    elif config.data_config == "configs/dynamics/spring_dynamics_data.py":
        task = "spring"
    else:
        raise ValueError(
            f"Unrecognised data config {config.data_config!r}; expected the "
            "nbody or spring dynamics data config."
        )

    torch.manual_seed(config.model_seed)
    network = DynamicsEquivariantTransformer(
        group=group,
        dim_input=config.sys_dim,
        dim_output=1,  # Potential term in Hamiltonian is scalar
        dim_hidden=config.dim_hidden,
        num_layers=config.num_layers,
        num_heads=config.num_heads,
        global_pool=True,
        global_pool_mean=config.mean_pooling,
        liftsamples=config.lift_samples,
        kernel_dim=config.kernel_dim,
        kernel_act=config.activation_function,
        block_norm=config.block_norm,
        output_norm=config.output_norm,
        kernel_norm=config.kernel_norm,
        kernel_type=config.kernel_type,
        architecture=config.architecture,
        attention_fn=config.attention_fn,
    )

    dynamics_predictor = DynamicsPredictor(network,
                                           debug=config.debug,
                                           task=task)

    return dynamics_predictor, "EqvTransformer_Dynamics"
Example #4
0
def load(config, **unused_kwargs):
    """Build a ``MoleculePredictor`` wrapping a ``MoleculeEquivariantTransformer``.

    Args:
        config: experiment configuration. ``config.group`` selects the group
            ("SE3", "SO3", "T3" or "Trivial3"); the other attributes read below
            are forwarded to the network and predictor.
        **unused_kwargs: accepted for interface compatibility and ignored.

    Returns:
        ``(molecule_predictor, name)`` where ``name`` is
        ``"MoleculeEquivariantTransformer_<config.group>_<config.architecture>"``.

    Raises:
        ValueError: if ``config.group`` is not one of the supported names.
    """
    if config.group == "SE3":
        group = SE3(0.2)
    elif config.group == "SO3":
        group = SO3(0.2)
    elif config.group == "T3":
        group = T(3)
    elif config.group == "Trivial3":
        group = Trivial(3)
    else:
        raise ValueError(f"{config.group} is an invalid group")

    # Seed right before construction so weight initialisation is reproducible.
    torch.manual_seed(config.model_seed)
    predictor = MoleculeEquivariantTransformer(
        config.num_species,
        config.charge_scale,
        architecture=config.architecture,
        group=group,
        aug=config.data_augmentation,
        dim_hidden=config.dim_hidden,
        num_layers=config.num_layers,
        num_heads=config.num_heads,
        global_pool=True,
        global_pool_mean=config.mean_pooling,
        liftsamples=config.lift_samples,
        block_norm=config.block_norm,
        output_norm=config.output_norm,
        kernel_norm=config.kernel_norm,
        kernel_type=config.kernel_type,
        kernel_dim=config.kernel_dim,
        kernel_act=config.activation_function,
        fill=config.fill,
        mc_samples=config.mc_samples,
        attention_fn=config.attention_fn,
        feature_embed_dim=config.feature_embed_dim,
        max_sample_norm=config.max_sample_norm,
        lie_algebra_nonlinearity=config.lie_algebra_nonlinearity,
    )

    # NOTE(review): disabled manual rescaling of the final layer; purpose of
    # the constants is not documented here.
    # predictor.net[-1][-1].weight.data = predictor.net[-1][-1].weight * (0.205 / 0.005)
    # predictor.net[-1][-1].bias.data = predictor.net[-1][-1].bias - (0.196 + 0.40)

    molecule_predictor = MoleculePredictor(predictor, config.task,
                                           config.ds_stats)

    return (
        molecule_predictor,
        f"MoleculeEquivariantTransformer_{config.group}_{config.architecture}",
    )
Example #5
0
 def __init__(self,
              *args,
              group=T(3),
              ds_frac=1,
              fill=1 / 3,
              cache=False,
              knn=False,
              **kwargs):
     """Initialise the layer for equivariance to ``group``.

     Any caller-supplied ``xyz_dim`` is discarded. Instead the parent is given
     ``group.lie_dim + 2 * group.q_dim``. The remaining ``*args``/``**kwargs``
     are forwarded unchanged.

     Args:
         group: equivariance group for LieConv.
         ds_frac: downsampling fraction handed to ``FPSsubsample``.
         fill: average fraction of the input that enters each local
             neighbourhood (clamped to at most 1 for ``fill_frac``).
         cache: forwarded to ``FPSsubsample``.
         knn: use the k nearest points instead of random samples for the
             convolution estimator.
     """
     kwargs.pop('xyz_dim', None)
     embedded_dim = group.lie_dim + 2 * group.q_dim
     super().__init__(*args, xyz_dim=embedded_dim, **kwargs)
     self.group = group
     # Neighbourhood radius; starts at 2 and is adjusted according to fill.
     self.register_buffer('r', torch.tensor(2.))
     # Target neighbourhood fraction (never above 1); determines r.
     self.fill_frac = min(fill, 1.)
     self.knn = knn
     self.subsample = FPSsubsample(ds_frac, cache=cache, group=group)
     # Internal coefficient used when updating r.
     self.coeff = .5
     # Running average of the observed fill fraction, for logging only.
     self.fill_frac_ema = fill
Example #6
0
def load(config, **unused_kwargs):
    """Build a ``Classifier`` wrapping a ``ConstellationEquivariantTransformer``.

    Args:
        config: experiment configuration. ``config.group`` selects the group
            ("SE3", "SE2", "SO3", "T3", "T2" or "Trivial3"), and
            ``config.content_type`` selects the per-point featurisation
            ("centroidal", "constant", "pairwise_distances" or
            "distance_moments").
        **unused_kwargs: accepted for interface compatibility and ignored.

    Returns:
        ``(classifier, "ConstellationEquivariantTransformer_<config.group>")``.

    Raises:
        ValueError: if ``config.group`` is not supported.
        NotImplementedError: if ``config.content_type`` is not supported.
    """
    # A single if/elif chain: previously "SE2" started a second `if`, so
    # selecting "SE3" fell through to the `else` and raised ValueError.
    if config.group == "SE3":
        group = SE3(0.2)
    elif config.group == "SE2":
        group = SE2(0.2)
    elif config.group == "SO3":
        group = SO3(0.2)
    elif config.group == "T3":
        group = T(3)
    elif config.group == "T2":
        group = T(2)
    elif config.group == "Trivial3":
        group = Trivial(3)
    else:
        raise ValueError(f"{config.group} is an invalid group")

    if config.content_type == "centroidal":
        dim_input = 2
        feature_function = constant_features
    elif config.content_type == "constant":
        dim_input = 1
        # One all-ones feature per point, matching X's dtype and device.
        feature_function = lambda X, presence: torch.ones(X.shape[:-1], dtype=X.dtype, device=X.device).unsqueeze(-1)
    elif config.content_type == "pairwise_distances":
        dim_input = config.patterns_reps * 17 - 1
        feature_function = pairwise_distance_features  # i.e. use the arg dim_input
    elif config.content_type == "distance_moments":
        dim_input = config.distance_moments
        feature_function = lambda X, presence: pairwise_distance_moment_features(
            X, presence, n_moments=config.distance_moments
        )
    else:
        raise NotImplementedError(
            f"{config.content_type} featurization not implemented"
        )

    torch.manual_seed(config.model_seed)
    predictor = ConstellationEquivariantTransformer(
        n_patterns=4,
        patterns_reps=config.patterns_reps,
        feature_function=feature_function,
        group=group,
        dim_input=dim_input,
        dim_hidden=config.dim_hidden,
        num_layers=config.num_layers,
        num_heads=config.num_heads,
        # layer_norm=config.layer_norm,
        global_pool=True,
        global_pool_mean=config.mean_pooling,
        liftsamples=config.lift_samples,
        kernel_dim=config.kernel_dim,
        kernel_act=config.activation_function,
        block_norm=config.block_norm,
        output_norm=config.output_norm,
        kernel_norm=config.kernel_norm,
        kernel_type=config.kernel_type,
        architecture=config.architecture,
        # batch_norm=config.batch_norm,
        # location_attention=config.location_attention,
        attention_fn=config.attention_fn,
    )

    classifier = Classifier(predictor)

    return classifier, f"ConstellationEquivariantTransformer_{config.group}"
Example #7
0
 def __init__(self,
              d=2,
              sys_dim=2,
              bn=False,
              num_layers=4,
              group=T(2),
              k=384,
              knn=False,
              nbhd=100,
              mean=True,
              **kwargs):
     """Configure the parent network for a ``d``-dimensional system.

     The input channel count is ``sys_dim + d`` and the output count is
     ``2 * d``. Downsampling is disabled (``ds_frac=1``), ``fill`` is fixed at
     1, caching is on and pooling is off. Extra ``**kwargs`` are forwarded.
     """
     in_channels = sys_dim + d
     out_dim = 2 * d
     super().__init__(
         chin=in_channels, ds_frac=1, num_layers=num_layers, nbhd=nbhd,
         mean=mean, bn=bn, xyz_dim=d, group=group, fill=1., k=k,
         num_outputs=out_dim, cache=True, knn=knn, pool=False, **kwargs,
     )
     # NOTE(review): presumably counts function evaluations; confirm with callers.
     self.nfe = 0
Example #8
0
import lie_conv.moleculeTrainer as moleculeTrainer
import lie_conv.lieGroups as lieGroups
from lie_conv.lieGroups import T,Trivial,SE3,SO3
import lie_conv.lieConv as lieConv
from lie_conv.lieConv import ImgLieResnet
from lie_conv.datasets import MnistRotDataset
from examples.train_molec import makeTrainer,Trial
from oil.tuning.study import Study

def trial_name(cfg):
    """Return a log suffix built from the net's fill, nbhd, group and the lr."""
    net = cfg['net_config']
    fill, nbhd, group = net['fill'], net['nbhd'], net['group']
    return f"molec_f{fill}_n{nbhd}_{group}_{cfg['lr']}"


def bigG(cfg):
    """Return True when the configured network group is an ``SE3`` or ``SO3``."""
    chosen = cfg['net_config']['group']
    return isinstance(chosen, SE3) or isinstance(chosen, SO3)

if __name__ == '__main__':
    # `copy` and `argupdated_config` are used below but are not provided by the
    # imports at the top of this script.
    import copy
    from oil.tuning.args import argupdated_config

    config_spec = copy.deepcopy(makeTrainer.__kwdefaults__)
    config_spec.update({
        'num_epochs': 500,
        'net_config': {
            # SE3/SO3 ("big" groups) use a smaller fill and neighbourhood but
            # more lift samples than the other groups.
            'fill': lambda cfg: 1 / 2 if bigG(cfg) else 1.,
            'nbhd': lambda cfg: 25 if bigG(cfg) else 100,
            'group': T(3),
            'liftsamples': lambda cfg: 4 if bigG(cfg) else 1,
        },
        'recenter': lambda cfg: bigG(cfg),
        'lr': 3e-3,
        'bs': lambda cfg: 75 if bigG(cfg) else 100,
        # Molecular property targets; the HOMO-energy key is 'homo'.
        'task': ['alpha', 'gap', 'homo', 'lumo', 'mu', 'Cv', 'G', 'H', 'r2',
                 'U', 'U0', 'zpve'],
        'trainer_config': {
            'log_dir': 'molec_all_tasks4',
            'log_suffix': lambda cfg: trial_name(cfg),
        },
    })
    config_spec = argupdated_config(config_spec, namespace=(moleculeTrainer, lieGroups))
    thestudy = Study(Trial, config_spec, study_name='molec_all_tasks4')
    thestudy.run(num_trials=-1, ordered=True)
    print(thestudy.results_df())
            cfg['trainer_config']['log_suffix'] = os.path.join(orig_suffix,f'trial{i}/')
        trainer = self.make_trainer(**cfg)
        trainer.logger.add_scalars('config', flatten_dict(cfg))
        trainer.train(cfg['num_epochs'])
        outcome = trainer.logger.scalar_frame.iloc[-1:]
        trainer.logger.save_object(trainer.model.state_dict(),suffix=f'checkpoints/final.state')
        trainer.logger.save_object(trainer.logger.scalar_frame,suffix=f'scalars.df')

        return cfg, outcome

Trial = MiniTrial(makeTrainer)

# Per-network hyperparameter settings (presumably the tuned "best" values).
# Commented-out entries are disabled baselines.
best_hypers = [
    # dict(network=FC, net_cfg=dict(k=256), lr=3e-3),
    # dict(network=HFC, net_cfg=dict(k=256, num_layers=4), lr=1e-2),
    dict(network=HLieResNet, net_cfg=dict(k=384, num_layers=4, group=T(2)), lr=1e-3),
    dict(network=AugHLieResNet, net_cfg=dict(k=384, num_layers=4, group=T(2)), lr=1e-3),
    # dict(network=VOGN, net_cfg=dict(k=512), lr=3e-3),
    # dict(network=HOGN, net_cfg=dict(k=256), lr=1e-2),
    # dict(network=OGN, net_cfg=dict(k=256), lr=1e-2),
]

if __name__ == '__main__':
    config_spec = copy.deepcopy(makeTrainer.__kwdefaults__)
    config_spec.update({
        'num_epochs':(lambda cfg: int(np.sqrt(1e7/cfg['n_train']))),
        'n_train':[10,25,50,100,400,1000,3000,10000,30000,100000-4000],
    })
    config_spec = argupdated_config(config_spec)
    name = 'aug_data_scaling_dynamics_final'#config_spec.pop('study_name')
    num_repeats = 3#config_spec.pop('num_repeats')