def test_covariance(self, tau, num_channels, maxl, sample_batch):
    """Check that ``CormorantAtomLevel`` commutes with rotations.

    Two paths are compared:

    * the atom level is run on the unrotated atom reps / spherical
      harmonics, and its output is then transformed with
      ``apply_wigner(D)``;
    * the atom reps are transformed with ``apply_wigner(D)`` and the
      spherical harmonics are recomputed from positions rotated by ``R``
      *before* running the atom level.

    The two outputs must agree to within 1e-5 for every index in
    ``range(maxl)``.
    """
    data, __, __ = sample_batch
    device, dtype = data['positions'].device, data['positions'].dtype

    sph_harms = SphericalHarmonicsRel(maxl - 1, conj=True, device=device,
                                      dtype=dtype, cg_dict=None)
    D, R, __ = rot.gen_rot(maxl, device=device, dtype=dtype)

    # Build atom layer: ``tau`` channels for each of the ``maxl`` entries.
    tlist = [tau] * maxl
    atom_lvl = CormorantAtomLevel(tlist, tlist, maxl, num_channels, 1, 'rand',
                                  device=device, dtype=dtype, cg_dict=None)

    # Set up input.
    atom_rep, atom_mask, edge_scalars, edge_mask, atom_positions = prep_input(
        data, tau, maxl)
    atom_positions_rot = rot.rotate_cart_vec(R, atom_positions)

    # Path 1: run the layer on unrotated inputs, then rotate its output.
    spherical_harmonics, __ = sph_harms(atom_positions, atom_positions)
    # Replicate each harmonic ``tau`` times along axis -3 so the edge reps
    # match the layer's ``tau``-wide inputs (axis -3 assumed to be channels).
    edge_reps = SO3Vec([torch.cat([sph_l] * tau, axis=-3)
                        for sph_l in spherical_harmonics])
    output = atom_lvl(atom_rep, edge_reps, atom_mask)
    output = output.apply_wigner(D)

    # Path 2: rotate the inputs, then run the layer.
    atom_rep_rot = atom_rep.apply_wigner(D)
    spherical_harmonics_rot, __ = sph_harms(atom_positions_rot,
                                            atom_positions_rot)
    edge_reps_rot = SO3Vec([torch.cat([sph_l] * tau, axis=-3)
                            for sph_l in spherical_harmonics_rot])
    output_from_rot = atom_lvl(atom_rep_rot, edge_reps_rot, atom_mask)

    for i in range(maxl):
        assert torch.max(torch.abs(output_from_rot[i] - output[i])) < 1E-5
def __init__(self, maxl, max_sh, num_cg_levels, num_channels, num_species,
             cutoff_type, hard_cut_rad, soft_cut_rad, soft_cut_width,
             weight_init, level_gain, charge_power, basis_set, charge_scale,
             gaussian_mask, top, input, num_mpnn_layers,
             activation='leakyrelu', device=None, dtype=None, cg_dict=None):
    """Initialize a Cormorant-style network with an MPNN atom input layer.

    Pipeline built here, in order: relative spherical harmonics, radial
    filters, ``InputMPNN`` atom input (no edge input), the ``ENN``
    Clebsch-Gordan core, atom scalar extraction, and an ``OutputPMLP``
    atom output (no edge output).

    Per-level options (``level_gain``, ``hard_cut_rad``, ``soft_cut_rad``,
    ``soft_cut_width``, ``maxl``, ``max_sh``) are passed through
    ``expand_var_list(..., num_cg_levels)``. ``num_channels`` is passed
    through ``expand_var_list(..., num_cg_levels + 1)``, and its first entry
    is the width of the input layer.

    NOTE(review): ``gaussian_mask``, ``top`` and ``input`` are accepted but
    not used by this constructor; the ENN is always built with
    ``gaussian_mask=False``. Confirm whether ``gaussian_mask`` should be
    forwarded.
    """
    logging.info('Initializing network!')
    level_gain = expand_var_list(level_gain, num_cg_levels)
    hard_cut_rad = expand_var_list(hard_cut_rad, num_cg_levels)
    soft_cut_rad = expand_var_list(soft_cut_rad, num_cg_levels)
    soft_cut_width = expand_var_list(soft_cut_width, num_cg_levels)
    maxl = expand_var_list(maxl, num_cg_levels)
    max_sh = expand_var_list(max_sh, num_cg_levels)
    num_channels = expand_var_list(num_channels, num_cg_levels + 1)

    logging.info('hard_cut_rad: {}'.format(hard_cut_rad))
    logging.info('soft_cut_rad: {}'.format(soft_cut_rad))
    logging.info('soft_cut_width: {}'.format(soft_cut_width))
    logging.info('maxl: {}'.format(maxl))
    logging.info('max_sh: {}'.format(max_sh))
    logging.info('num_channels: {}'.format(num_channels))

    # maxl and max_sh are lists at this point, so ``maxl + max_sh`` is a
    # concatenation and the max is taken over every entry of both.
    super().__init__(maxl=max(maxl + max_sh), device=device, dtype=dtype,
                     cg_dict=cg_dict)
    device, dtype, cg_dict = self.device, self.dtype, self.cg_dict
    # Diagnostic goes through logging rather than stdout.
    logging.debug('CGDICT maxl: {}'.format(cg_dict.maxl))

    self.num_cg_levels = num_cg_levels
    self.num_channels = num_channels
    self.charge_power = charge_power
    self.charge_scale = charge_scale
    self.num_species = num_species

    # Set up spherical harmonics
    self.sph_harms = SphericalHarmonicsRel(max(max_sh), conj=True,
                                           device=device, dtype=dtype,
                                           cg_dict=cg_dict)

    # Set up position functions, now independent of spherical harmonics
    self.rad_funcs = RadialFilters(max_sh, basis_set, num_channels,
                                   num_cg_levels, device=self.device,
                                   dtype=self.dtype)
    tau_pos = self.rad_funcs.tau

    # Input layers: MPNN on atom scalars, nothing on edges.
    num_scalars_in = self.num_species * (self.charge_power + 1)
    num_scalars_out = num_channels[0]
    self.input_func_atom = InputMPNN(num_scalars_in, num_scalars_out,
                                     num_mpnn_layers, soft_cut_rad[0],
                                     soft_cut_width[0], hard_cut_rad[0],
                                     activation=activation,
                                     device=self.device, dtype=self.dtype)
    self.input_func_edge = NoLayer()

    # Central Clebsch-Gordan network.
    tau_in_atom = self.input_func_atom.tau
    tau_in_edge = self.input_func_edge.tau
    self.cormorant_cg = ENN(maxl, max_sh, tau_in_atom, tau_in_edge, tau_pos,
                           num_cg_levels, num_channels, level_gain,
                           weight_init, cutoff_type, hard_cut_rad,
                           soft_cut_rad, soft_cut_width, cat=True,
                           gaussian_mask=False, device=self.device,
                           dtype=self.dtype, cg_dict=self.cg_dict)

    # Scalar extraction (atoms only; edges use a NoLayer placeholder).
    tau_cg_levels_atom = self.cormorant_cg.tau_levels_atom
    self.get_scalars_atom = GetScalarsAtom(tau_cg_levels_atom,
                                           device=self.device,
                                           dtype=self.dtype)
    self.get_scalars_edge = NoLayer()

    # Output networks.
    num_scalars_atom = self.get_scalars_atom.num_scalars
    self.output_layer_atom = OutputPMLP(num_scalars_atom,
                                        activation=activation,
                                        device=self.device, dtype=self.dtype)
    self.output_layer_edge = NoLayer()

    logging.info('Model initialized. Number of parameters: {}'.format(
        sum(p.nelement() for p in self.parameters())))
def __init__(
    self,
    maxl,
    max_sh,
    num_cg_levels,
    num_channels,
    num_species,
    cutoff_type,
    hard_cut_rad,
    soft_cut_rad,
    soft_cut_width,
    weight_init,
    level_gain,
    charge_power,
    basis_set,
    charge_scale,
    bag_scale,
    device=None,
    dtype=None,
    cg_dict=None,
) -> None:
    """Assemble the spherical harmonics, radial filters, linear input and
    CormorantCG core of the network.

    Per-level options go through ``expand_var_list(..., num_cg_levels)``,
    and ``num_channels`` through ``expand_var_list(..., num_cg_levels + 1)``.
    The atom input width is ``num_species * (charge_power + 1)`` plus one
    extra input per species. ``bag_scale`` and ``charge_scale`` are stored
    on the instance.
    """
    # Broadcast per-level hyperparameters (same call order as listed).
    (level_gain, hard_cut_rad, soft_cut_rad, soft_cut_width,
     maxl, max_sh) = [expand_var_list(option, num_cg_levels)
                      for option in (level_gain, hard_cut_rad, soft_cut_rad,
                                     soft_cut_width, maxl, max_sh)]
    num_channels = expand_var_list(num_channels, num_cg_levels + 1)

    super().__init__(maxl=max(maxl + max_sh), device=device, dtype=dtype,
                     cg_dict=cg_dict)

    self.num_cg_levels = num_cg_levels
    self.num_channels = num_channels
    self.charge_power = charge_power
    self.charge_scale = charge_scale
    self.bag_scale = bag_scale
    self.num_species = num_species

    # Spherical harmonics up to the largest max_sh of any level.
    self.sph_harms = SphericalHarmonicsRel(
        maxl=max(max_sh),
        conj=True,
        device=self.device,
        dtype=self.dtype,
        cg_dict=self.cg_dict,
    )

    # Radial (position) filters, built independently of the harmonics.
    self.rad_funcs = RadialFilters(
        max_sh=max_sh,
        basis_set=basis_set,
        num_channels_out=num_channels,
        num_levels=num_cg_levels,
        device=self.device,
        dtype=self.dtype,
    )
    tau_pos = self.rad_funcs.tau

    # Linear atom input; no edge input.
    species_count = self.num_species
    num_scalars_in = species_count * (self.charge_power + 1) + species_count
    self.input_func_atom = InputLinear(num_scalars_in, num_channels[0],
                                       device=self.device, dtype=self.dtype)
    self.input_func_edge = NoLayer()

    # Clebsch-Gordan core.
    self.cormorant_cg = CormorantCG(
        maxl=maxl,
        max_sh=max_sh,
        tau_in_atom=self.input_func_atom.tau,
        tau_in_edge=self.input_func_edge.tau,
        tau_pos=tau_pos,
        num_cg_levels=num_cg_levels,
        num_channels=num_channels,
        level_gain=level_gain,
        weight_init=weight_init,
        cutoff_type=cutoff_type,
        hard_cut_rad=hard_cut_rad,
        soft_cut_rad=soft_cut_rad,
        soft_cut_width=soft_cut_width,
        cat=True,
        gaussian_mask=False,
        device=self.device,
        dtype=self.dtype,
        cg_dict=self.cg_dict,
    )
def __init__(self, num_cg_levels, maxl, max_sh, num_channels, num_species,
             cutoff_type, hard_cut_rad, soft_cut_rad, soft_cut_width,
             weight_init, level_gain, charge_power, basis_set, charge_scale,
             gaussian_mask, top, input, num_mpnn_layers, num_top_layers,
             num_scalars_in=None, num_out=1, activation='leakyrelu',
             additional_atom_features=None, device=torch.device('cpu'),
             dtype=torch.float):
    """Build an edge-output Cormorant network.

    For each of ``num_cg_levels`` levels a ``CormorantEdgeLevel`` is built,
    followed by a ``CormorantAtomLevel`` that consumes its output type. A
    linear or MLP top then maps the summed edge channels of all levels to
    ``num_out`` outputs.

    Args:
        num_cg_levels: Number of (edge level, atom level) pairs.
        maxl, max_sh, level_gain, hard_cut_rad, soft_cut_rad,
            soft_cut_width, num_channels: Per-level options, passed through
            ``expand_var_list(..., num_cg_levels)``.
            ``self.maxl``/``self.max_sh`` keep the raw, unexpanded values.
        num_species, charge_power: Set the default input width
            ``num_species * (charge_power + 1)``.
        input: ``'linear'`` or ``'mpnn'`` input layer (case-insensitive).
        top: ``'linear'`` or ``'mlp'`` output layer (case-insensitive).
        num_scalars_in: Overrides the default input width when not None.
        additional_atom_features: Optional iterable, stored as a list.

    Raises:
        ValueError: If ``input`` or ``top`` is not a recognized choice.
    """
    super(EdgeCormorant, self).__init__()
    logging.info('Initializing network!')

    self.num_cg_levels = num_cg_levels
    # Raw arguments as given by the caller (scalar or list).
    self.maxl = maxl
    self.max_sh = max_sh
    self.device = device
    self.dtype = dtype

    level_gain = expand_var_list(level_gain, num_cg_levels)
    hard_cut_rad = expand_var_list(hard_cut_rad, num_cg_levels)
    soft_cut_rad = expand_var_list(soft_cut_rad, num_cg_levels)
    soft_cut_width = expand_var_list(soft_cut_width, num_cg_levels)
    maxl = expand_var_list(maxl, num_cg_levels)
    max_sh = expand_var_list(max_sh, num_cg_levels)
    num_channels = expand_var_list(num_channels, num_cg_levels)

    self.num_channels = num_channels
    self.charge_power = charge_power
    self.charge_scale = charge_scale
    self.num_species = num_species
    if additional_atom_features is None:
        self.additional_atom_features = []
    else:
        self.additional_atom_features = list(additional_atom_features)

    logging.info('hard_cut_rad: {}'.format(hard_cut_rad))
    logging.info('soft_cut_rad: {}'.format(soft_cut_rad))
    logging.info('soft_cut_width: {}'.format(soft_cut_width))
    logging.info('maxl: {}'.format(maxl))
    logging.info('max_sh: {}'.format(max_sh))
    logging.info('num_channels: {}'.format(num_channels))

    # Set up spherical harmonics. Use the expanded ``max_sh`` list: the raw
    # ``self.max_sh`` may be a plain int, and max() of an int raises
    # TypeError.
    self.spherical_harmonics_rel = SphericalHarmonicsRel(
        max(max_sh), sh_norm='unit', device=device, dtype=dtype)

    # Set up position functions, now independent of spherical harmonics
    self.position_functions = RadialFilters(max_sh, basis_set, num_channels,
                                            num_cg_levels, device=device,
                                            dtype=dtype)
    tau_pos = self.position_functions.tau

    if num_scalars_in is None:
        num_scalars_in = self.num_species * (self.charge_power + 1)
    num_scalars_out = num_channels[0]

    # Local name avoids rebinding the builtin-shadowing ``input`` parameter.
    input_type = input.lower()
    if input_type == 'linear':
        self.input_func = InputLinear(num_scalars_in, num_scalars_out,
                                      device=self.device, dtype=self.dtype)
    elif input_type == 'mpnn':
        self.input_func = InputMPNN(num_scalars_in, num_scalars_out,
                                    num_mpnn_layers, soft_cut_rad[0],
                                    soft_cut_width[0], hard_cut_rad[0],
                                    activation=activation,
                                    device=self.device, dtype=self.dtype)
    else:
        raise ValueError(
            'Improper choice of input featurization of network! {}'.format(
                input_type))

    tau_in = [num_scalars_out]
    tau_edge = [0]

    logging.info('{} {}'.format(tau_in, tau_edge))

    atom_levels = nn.ModuleList()
    edge_levels = nn.ModuleList()
    for level in range(self.num_cg_levels):
        # First add the edge, since the output type determines the next level
        edge_lvl = CormorantEdgeLevel(tau_edge, tau_in, tau_pos[level],
                                      num_channels[level], cutoff_type,
                                      hard_cut_rad[level],
                                      soft_cut_rad[level],
                                      soft_cut_width[level],
                                      gaussian_mask=gaussian_mask,
                                      device=device, dtype=dtype)
        edge_levels.append(edge_lvl)
        tau_edge = edge_lvl.tau_out

        # Now add the NBody level
        atom_lvl = CormorantAtomLevel(tau_in, tau_edge, maxl[level],
                                      num_channels[level], level_gain[level],
                                      weight_init, device=device, dtype=dtype)
        atom_levels.append(atom_lvl)
        tau_in = atom_lvl.tau_out

        logging.info('{} {}'.format(tau_in, tau_edge))

    self.atom_levels = atom_levels
    self.edge_levels = edge_levels

    # The top consumes the edge channels of every level together.
    self.tau_levels_out = [level.tau_out for level in edge_levels]
    num_mlp_channels = sum(sum(tau) for tau in self.tau_levels_out)

    top_type = top.lower()
    if top_type == 'linear':
        self.top_func = OutputEdgeLinear(num_mlp_channels, num_out=num_out,
                                         bias=True, device=self.device,
                                         dtype=self.dtype)
    elif top_type == 'mlp':
        self.top_func = OutputEdgeMLP(num_mlp_channels, num_out=num_out,
                                      num_hidden=num_top_layers,
                                      activation=activation,
                                      device=self.device, dtype=self.dtype)
    else:
        raise ValueError(
            'Improper choice of top of network! {}'.format(top_type))

    logging.info('Model initialized. Number of parameters: {}'.format(
        sum(p.nelement() for p in self.parameters())))
def __init__(self, maxl, max_sh, num_cg_levels, num_channels, num_species,
             cutoff_type, hard_cut_rad, soft_cut_rad, soft_cut_width,
             weight_init, level_gain, charge_power, basis_set, charge_scale,
             gaussian_mask, cgprod_bounded=True, cg_pow_normalization='none',
             cg_agg_normalization='none', device=None, dtype=None,
             cg_dict=None):
    """Assemble a Cormorant network with linear input and linear output.

    Built in order: relative spherical harmonics, radial filters, an
    ``InputLinear`` atom input, the ``ENN`` Clebsch-Gordan core (receiving
    the ``cgprod_bounded`` / ``cg_pow_normalization`` /
    ``cg_agg_normalization`` options), atom scalar extraction, and an
    ``OutputLinear`` atom output. Edge input, scalar and output stages are
    ``NoLayer`` placeholders.

    NOTE(review): ``gaussian_mask`` is accepted but not forwarded anywhere
    in this constructor.
    """
    logging.info('Initializing network!')

    # Broadcast per-level hyperparameters (same call order as listed).
    per_level = [expand_var_list(option, num_cg_levels)
                 for option in (level_gain, hard_cut_rad, soft_cut_rad,
                                soft_cut_width, maxl, max_sh)]
    (level_gain, hard_cut_rad, soft_cut_rad, soft_cut_width,
     maxl, max_sh) = per_level
    num_channels = expand_var_list(num_channels, num_cg_levels+1)

    for template, setting in (('hard_cut_rad: {}', hard_cut_rad),
                              ('soft_cut_rad: {}', soft_cut_rad),
                              ('soft_cut_width: {}', soft_cut_width),
                              ('maxl: {}', maxl),
                              ('max_sh: {}', max_sh),
                              ('num_channels: {}', num_channels)):
        logging.info(template.format(setting))

    super().__init__(maxl=max(maxl+max_sh), device=device, dtype=dtype,
                     cg_dict=cg_dict)
    device, dtype, cg_dict = self.device, self.dtype, self.cg_dict

    self.num_cg_levels = num_cg_levels
    self.num_channels = num_channels
    self.charge_power = charge_power
    self.charge_scale = charge_scale
    self.num_species = num_species

    # Spherical harmonics up to the largest max_sh of any level.
    self.sph_harms = SphericalHarmonicsRel(max(max_sh), conj=True,
                                           device=device, dtype=dtype,
                                           cg_dict=cg_dict)

    # Radial (position) filters, built independently of the harmonics.
    self.rad_funcs = RadialFilters(max_sh, basis_set, num_channels,
                                   num_cg_levels, device=self.device,
                                   dtype=self.dtype)
    tau_pos = self.rad_funcs.tau

    # Input stage: linear map on atom scalars, nothing on edges.
    scalars_per_species = self.charge_power + 1
    num_scalars_in = self.num_species * scalars_per_species
    self.input_func_atom = InputLinear(num_scalars_in, num_channels[0],
                                       device=self.device, dtype=self.dtype)
    self.input_func_edge = NoLayer()

    # Central Clebsch-Gordan network.
    self.cormorant_cg = ENN(maxl, max_sh,
                           self.input_func_atom.tau, self.input_func_edge.tau,
                           tau_pos, num_cg_levels, num_channels, level_gain,
                           weight_init, cutoff_type, hard_cut_rad,
                           soft_cut_rad, soft_cut_width,
                           cgprod_bounded=cgprod_bounded,
                           cg_pow_normalization=cg_pow_normalization,
                           cg_agg_normalization=cg_agg_normalization,
                           device=self.device, dtype=self.dtype,
                           cg_dict=self.cg_dict)

    # Scalar extraction from the CG levels.
    tau_cg_levels_atom = self.cormorant_cg.tau_levels_atom
    tau_cg_levels_edge = self.cormorant_cg.tau_levels_edge  # read, not used
    self.get_scalars_atom = GetScalarsAtom(tau_cg_levels_atom,
                                           device=self.device,
                                           dtype=self.dtype)
    self.get_scalars_edge = NoLayer()

    # Output stage.
    num_scalars_atom = self.get_scalars_atom.num_scalars
    num_scalars_edge = self.get_scalars_edge.num_scalars  # read, not used
    self.output_layer_atom = OutputLinear(num_scalars_atom, bias=True,
                                          device=self.device,
                                          dtype=self.dtype)
    self.output_layer_edge = NoLayer()

    param_count = sum(p.nelement() for p in self.parameters())
    logging.info('Model initialized. Number of parameters: {}'.format(
        param_count))
def test_covariance(self, tau, num_channels, maxl, basis, edge_net_type,
                    sample_batch):
    """Check that ``CormorantEdgeLevel`` output ignores a rotation of the atom reps.

    The edge level is evaluated twice with identical edge reps, radial
    functions, edge mask and norms. Only the atom representations differ:
    the originals versus the same reps transformed with
    ``apply_wigner(D)``. The outputs must agree to within 1e-5 for every
    index in ``range(maxl)``.

    ``edge_net_type`` selects the initial edge reps: ``None`` for no edge
    input, or ``'rand'`` for random scalar edge reps. Any other value
    raises ``ValueError``.
    """
    data, __, __ = sample_batch
    device, dtype = data['positions'].device, data['positions'].dtype

    sph_harms = SphericalHarmonicsRel(maxl - 1, conj=True, device=device,
                                      dtype=dtype, cg_dict=None)
    batch_size, natoms = data['positions'].shape[:2]
    D, __, __ = rot.gen_rot(maxl, device=device, dtype=dtype)

    # Set up input and its rotated copy.
    atom_reps, atom_mask, edge_scalars, edge_mask, atom_positions = prep_input(
        data, tau, maxl)
    atom_reps_rot = atom_reps.apply_wigner(D)

    # Norms and radial functions. Only the atom reps are rotated below, so
    # both evaluations share these.
    __, norms = sph_harms(atom_positions, atom_positions)
    rad_funcs = RadialFilters([maxl - 1], [basis, basis], [num_channels], 1,
                              device=device, dtype=dtype)
    rad_func_levels = rad_funcs(norms, edge_mask * (norms > 0))
    tau_pos = rad_funcs.tau[0]

    # Build the initial edge network.
    if edge_net_type is None:
        edge_reps = None
    elif edge_net_type == 'rand':
        # Create on the batch's device/dtype; torch.randn otherwise uses the
        # default device and dtype, which mismatches GPU or float64 batches.
        reps = [
            torch.randn((batch_size, natoms, natoms, tau, 2),
                        device=device, dtype=dtype)
            for _ in range(maxl)
        ]
        edge_reps = SO3Scalar(reps)
    else:
        raise ValueError('Unknown edge_net_type: {}'.format(edge_net_type))

    # Build edge layer.
    tlist = [tau] * maxl
    tau_atoms = tlist
    tau_edge = [] if edge_net_type is None else tlist

    edge_lvl = CormorantEdgeLevel(tau_atoms, tau_edge, tau_pos, num_channels,
                                  maxl, cutoff_type='soft', device=device,
                                  dtype=dtype, hard_cut_rad=1.73,
                                  soft_cut_rad=1.73, soft_cut_width=0.2)

    output_edge_reps = edge_lvl(edge_reps, atom_reps, rad_func_levels[0],
                                edge_mask, norms)
    output_edge_reps_rot = edge_lvl(edge_reps, atom_reps_rot,
                                    rad_func_levels[0], edge_mask, norms)

    for i in range(maxl):
        assert torch.max(
            torch.abs(output_edge_reps[i] - output_edge_reps_rot[i])) < 1E-5