示例#1
0
    def test_covariance(self, tau, num_channels, maxl, sample_batch):
        """Check that CormorantAtomLevel commutes with rotations.

        The atom level is evaluated twice:

        1. on the original inputs, after which the Wigner-D matrices ``D`` are
           applied to its output;
        2. on inputs rotated beforehand: atom representations via
           ``apply_wigner(D)``, and edge representations rebuilt from the
           spherical harmonics of the rotated positions.

        For each index ``i`` in ``range(maxl)`` the two results must agree
        elementwise to within an absolute tolerance of 1e-5.

        Args:
            tau: Channel multiplicity used for every irrep of both the atom
                and the edge inputs.
            num_channels: Channel count passed to ``CormorantAtomLevel``.
            maxl: Number of irreps in the inputs; spherical harmonics are built
                up to degree ``maxl - 1``.
            sample_batch: Fixture whose first element is the data dict;
                ``data['positions']`` fixes the device and dtype used.
        """
        data, __, __ = sample_batch
        device, dtype = data['positions'].device, data['positions'].dtype
        sph_harms = SphericalHarmonicsRel(maxl - 1,
                                          conj=True,
                                          device=device,
                                          dtype=dtype,
                                          cg_dict=None)
        D, R, __ = rot.gen_rot(maxl, device=device, dtype=dtype)

        # Build the atom level; atom and edge inputs both carry `tau` channels
        # per irrep. NOTE(review): the positional 1 and 'rand' look like
        # level_gain and weight_init -- confirm against CormorantAtomLevel.
        tlist = [tau] * maxl
        atom_lvl = CormorantAtomLevel(tlist,
                                      tlist,
                                      maxl,
                                      num_channels,
                                      1,
                                      'rand',
                                      device=device,
                                      dtype=dtype,
                                      cg_dict=None)

        # Setup input; only atom reps, mask and positions are needed here.
        atom_rep, atom_mask, __, __, atom_positions = prep_input(
            data, tau, maxl)
        atom_positions_rot = rot.rotate_cart_vec(R, atom_positions)

        # Path 1: evaluate on the original inputs, then rotate the output.
        spherical_harmonics, __ = sph_harms(atom_positions, atom_positions)
        edge_rep_list = [
            torch.cat([sph_l] * tau, axis=-3) for sph_l in spherical_harmonics
        ]
        edge_reps = SO3Vec(edge_rep_list)

        output = atom_lvl(atom_rep, edge_reps, atom_mask)
        output = output.apply_wigner(D)

        # Path 2: rotate the inputs first, then evaluate.
        atom_rep_rot = atom_rep.apply_wigner(D)
        spherical_harmonics_rot, __ = sph_harms(atom_positions_rot,
                                                atom_positions_rot)
        edge_rep_list_rot = [
            torch.cat([sph_l] * tau, axis=-3)
            for sph_l in spherical_harmonics_rot
        ]
        edge_reps_rot = SO3Vec(edge_rep_list_rot)
        output_from_rot = atom_lvl(atom_rep_rot, edge_reps_rot, atom_mask)

        for i in range(maxl):
            assert (torch.max(torch.abs(output_from_rot[i] - output[i])) <
                    1E-5)
示例#2
0
    def __init__(self, maxl, max_sh, num_cg_levels, num_channels, num_species,
                 cutoff_type, hard_cut_rad, soft_cut_rad, soft_cut_width,
                 weight_init, level_gain, charge_power, basis_set,
                 charge_scale, gaussian_mask,
                 top, input, num_mpnn_layers, activation='leakyrelu',
                 device=None, dtype=None, cg_dict=None):
        """Build the network: MPNN input, ENN Clebsch-Gordan core, PMLP output.

        The per-level hyperparameters ``level_gain``, ``hard_cut_rad``,
        ``soft_cut_rad``, ``soft_cut_width``, ``maxl`` and ``max_sh`` are
        broadcast to ``num_cg_levels`` entries with ``expand_var_list``.
        ``num_channels`` is broadcast to ``num_cg_levels + 1`` entries; entry 0
        is the output width of the input layer.

        Args:
            num_species: Number of species; with ``charge_power`` it fixes the
                input scalar width ``num_species * (charge_power + 1)``.
            num_mpnn_layers: Layer count forwarded to ``InputMPNN``.
            activation: Activation name forwarded to ``InputMPNN`` and
                ``OutputPMLP``.
            device, dtype, cg_dict: Forwarded to the base-class constructor.
                Afterwards the resolved ``self.device``, ``self.dtype`` and
                ``self.cg_dict`` are used.

        Note:
            ``gaussian_mask``, ``top`` and ``input`` are accepted but not read
            by this constructor. ENN is always built with ``cat=True`` and
            ``gaussian_mask=False``.
        """
        logging.info('Initializing network!')
        level_gain = expand_var_list(level_gain, num_cg_levels)

        hard_cut_rad = expand_var_list(hard_cut_rad, num_cg_levels)
        soft_cut_rad = expand_var_list(soft_cut_rad, num_cg_levels)
        soft_cut_width = expand_var_list(soft_cut_width, num_cg_levels)

        maxl = expand_var_list(maxl, num_cg_levels)
        max_sh = expand_var_list(max_sh, num_cg_levels)
        # One extra entry: index 0 sizes the input layer's output.
        num_channels = expand_var_list(num_channels, num_cg_levels+1)

        logging.info('hard_cut_rad: {}'.format(hard_cut_rad))
        logging.info('soft_cut_rad: {}'.format(soft_cut_rad))
        logging.info('soft_cut_width: {}'.format(soft_cut_width))
        logging.info('maxl: {}'.format(maxl))
        logging.info('max_sh: {}'.format(max_sh))
        logging.info('num_channels: {}'.format(num_channels))

        # The base class receives the largest value found in either maxl or
        # max_sh (lists are concatenated, then max is taken).
        super().__init__(maxl=max(maxl+max_sh), device=device, dtype=dtype, cg_dict=cg_dict)
        device, dtype, cg_dict = self.device, self.dtype, self.cg_dict

        # Report the CG dictionary size through logging, not stdout.
        logging.info('CGDICT maxl: {}'.format(cg_dict.maxl))

        self.num_cg_levels = num_cg_levels
        self.num_channels = num_channels
        self.charge_power = charge_power
        self.charge_scale = charge_scale
        self.num_species = num_species

        # Set up spherical harmonics
        self.sph_harms = SphericalHarmonicsRel(max(max_sh), conj=True,
                                               device=device, dtype=dtype, cg_dict=cg_dict)

        # Set up position functions, now independent of spherical harmonics
        self.rad_funcs = RadialFilters(max_sh, basis_set, num_channels, num_cg_levels,
                                       device=self.device, dtype=self.dtype)
        tau_pos = self.rad_funcs.tau

        num_scalars_in = self.num_species * (self.charge_power + 1)
        num_scalars_out = num_channels[0]

        self.input_func_atom = InputMPNN(num_scalars_in, num_scalars_out, num_mpnn_layers,
                                         soft_cut_rad[0], soft_cut_width[0], hard_cut_rad[0],
                                         activation=activation, device=self.device, dtype=self.dtype)
        self.input_func_edge = NoLayer()

        tau_in_atom = self.input_func_atom.tau
        tau_in_edge = self.input_func_edge.tau

        self.cormorant_cg = ENN(maxl, max_sh, tau_in_atom, tau_in_edge,
                     tau_pos, num_cg_levels, num_channels, level_gain, weight_init,
                     cutoff_type, hard_cut_rad, soft_cut_rad, soft_cut_width,
                     cat=True, gaussian_mask=False,
                     device=self.device, dtype=self.dtype, cg_dict=self.cg_dict)

        # Only atom-level scalars feed the output; the edge path is a NoLayer.
        self.get_scalars_atom = GetScalarsAtom(self.cormorant_cg.tau_levels_atom,
                                               device=self.device, dtype=self.dtype)
        self.get_scalars_edge = NoLayer()

        self.output_layer_atom = OutputPMLP(self.get_scalars_atom.num_scalars,
                                            activation=activation,
                                            device=self.device, dtype=self.dtype)
        self.output_layer_edge = NoLayer()

        logging.info('Model initialized. Number of parameters: {}'.format(
            sum(p.nelement() for p in self.parameters())))
示例#3
0
文件: modules.py 项目: gncs/molgym
    def __init__(
        self,
        maxl,
        max_sh,
        num_cg_levels,
        num_channels,
        num_species,
        cutoff_type,
        hard_cut_rad,
        soft_cut_rad,
        soft_cut_width,
        weight_init,
        level_gain,
        charge_power,
        basis_set,
        charge_scale,
        bag_scale,
        device=None,
        dtype=None,
        cg_dict=None,
    ) -> None:
        """Assemble the spherical-harmonic, radial, input and CG sub-modules.

        The per-level hyperparameters (``level_gain``, the three cutoff
        parameters, ``maxl`` and ``max_sh``) are broadcast to ``num_cg_levels``
        entries with ``expand_var_list``. ``num_channels`` is broadcast to
        ``num_cg_levels + 1`` entries, and entry 0 is the width produced by
        the linear input layer.

        ``device``, ``dtype`` and ``cg_dict`` are handed to the base class.
        The values it resolves (``self.device`` etc.) are what the
        sub-modules receive.
        """
        # Broadcast every per-level hyperparameter in one pass (same call order).
        (level_gain, hard_cut_rad, soft_cut_rad, soft_cut_width, maxl,
         max_sh) = [
             expand_var_list(value, num_cg_levels)
             for value in (level_gain, hard_cut_rad, soft_cut_rad,
                           soft_cut_width, maxl, max_sh)
         ]
        num_channels = expand_var_list(num_channels, num_cg_levels + 1)

        # Size the base class for the largest entry across maxl and max_sh.
        super().__init__(maxl=max(maxl + max_sh),
                         device=device,
                         dtype=dtype,
                         cg_dict=cg_dict)
        # From here on, use the values as resolved by the base class.
        device, dtype, cg_dict = self.device, self.dtype, self.cg_dict

        self.num_cg_levels = num_cg_levels
        self.num_channels = num_channels
        self.charge_power = charge_power
        self.charge_scale = charge_scale
        self.bag_scale = bag_scale
        self.num_species = num_species

        # Sub-modules are registered in a fixed order: sph_harms, rad_funcs,
        # input_func_atom, input_func_edge, cormorant_cg.
        self.sph_harms = SphericalHarmonicsRel(maxl=max(max_sh),
                                               conj=True,
                                               device=device,
                                               dtype=dtype,
                                               cg_dict=cg_dict)

        # Radial filters are independent of the spherical harmonics.
        self.rad_funcs = RadialFilters(max_sh=max_sh,
                                       basis_set=basis_set,
                                       num_channels_out=num_channels,
                                       num_levels=num_cg_levels,
                                       device=device,
                                       dtype=dtype)
        tau_pos = self.rad_funcs.tau

        # Input width: num_species * (charge_power + 1) charge features plus
        # num_species more (presumably the bag-of-species block -- confirm).
        charge_width = self.num_species * (self.charge_power + 1)
        self.input_func_atom = InputLinear(charge_width + self.num_species,
                                           num_channels[0],
                                           device=device,
                                           dtype=dtype)
        self.input_func_edge = NoLayer()

        cg_kwargs = dict(
            maxl=maxl,
            max_sh=max_sh,
            tau_in_atom=self.input_func_atom.tau,
            tau_in_edge=self.input_func_edge.tau,
            tau_pos=tau_pos,
            num_cg_levels=num_cg_levels,
            num_channels=num_channels,
            level_gain=level_gain,
            weight_init=weight_init,
            cutoff_type=cutoff_type,
            hard_cut_rad=hard_cut_rad,
            soft_cut_rad=soft_cut_rad,
            soft_cut_width=soft_cut_width,
            cat=True,
            gaussian_mask=False,
            device=device,
            dtype=dtype,
            cg_dict=cg_dict,
        )
        self.cormorant_cg = CormorantCG(**cg_kwargs)
    def __init__(self,
                 num_cg_levels,
                 maxl,
                 max_sh,
                 num_channels,
                 num_species,
                 cutoff_type,
                 hard_cut_rad,
                 soft_cut_rad,
                 soft_cut_width,
                 weight_init,
                 level_gain,
                 charge_power,
                 basis_set,
                 charge_scale,
                 gaussian_mask,
                 top,
                 input,
                 num_mpnn_layers,
                 num_top_layers,
                 num_scalars_in=None,
                 num_out=1,
                 activation='leakyrelu',
                 additional_atom_features=None,
                 device=torch.device('cpu'),
                 dtype=torch.float):
        """Build an edge-output Cormorant network.

        The network is assembled from these parts:

        - relative spherical harmonics;
        - radial filters;
        - an input featurization (``input``: ``'linear'`` or ``'mpnn'``);
        - ``num_cg_levels`` pairs of edge level followed by atom level;
        - a top network (``top``: ``'linear'`` or ``'mlp'``) whose input
          width is the sum of every edge level's ``tau_out`` channels.

        The per-level hyperparameters (``level_gain``, the cutoff radii and
        width, ``maxl``, ``max_sh``, ``num_channels``) are broadcast to
        ``num_cg_levels`` entries with ``expand_var_list``. ``self.maxl`` and
        ``self.max_sh`` keep the values exactly as passed in, not expanded.

        Args:
            num_scalars_in: Width of the input scalar features. Defaults to
                ``num_species * (charge_power + 1)``.
            num_top_layers: Forwarded as ``num_hidden`` to ``OutputEdgeMLP``.
            num_out: Output width of the top network.
            additional_atom_features: Optional iterable, stored as a list on
                ``self.additional_atom_features`` (empty list if ``None``).

        Raises:
            ValueError: If ``input`` or ``top`` (case-insensitive) is not a
                supported choice.
        """
        super(EdgeCormorant, self).__init__()

        logging.info('Initializing network!')

        self.num_cg_levels = num_cg_levels
        self.maxl = maxl
        self.max_sh = max_sh

        self.device = device
        self.dtype = dtype

        level_gain = expand_var_list(level_gain, num_cg_levels)

        hard_cut_rad = expand_var_list(hard_cut_rad, num_cg_levels)
        soft_cut_rad = expand_var_list(soft_cut_rad, num_cg_levels)
        soft_cut_width = expand_var_list(soft_cut_width, num_cg_levels)

        maxl = expand_var_list(maxl, num_cg_levels)
        max_sh = expand_var_list(max_sh, num_cg_levels)
        # Entry 0 sizes both the input layer output and CG level 0.
        num_channels = expand_var_list(num_channels, num_cg_levels)

        self.num_channels = num_channels
        self.charge_power = charge_power
        self.charge_scale = charge_scale
        self.num_species = num_species

        if additional_atom_features is None:
            self.additional_atom_features = []
        else:
            self.additional_atom_features = list(additional_atom_features)

        logging.info('hard_cut_rad: {}'.format(hard_cut_rad))
        logging.info('soft_cut_rad: {}'.format(soft_cut_rad))
        logging.info('soft_cut_width: {}'.format(soft_cut_width))
        logging.info('maxl: {}'.format(maxl))
        logging.info('max_sh: {}'.format(max_sh))
        logging.info('num_channels: {}'.format(num_channels))

        # Use the expanded per-level list. self.max_sh holds the raw argument,
        # which may be a scalar, and max() of a scalar raises TypeError.
        self.spherical_harmonics_rel = SphericalHarmonicsRel(max(max_sh),
                                                             sh_norm='unit',
                                                             device=device,
                                                             dtype=dtype)

        # Set up position functions, now independent of spherical harmonics
        self.position_functions = RadialFilters(max_sh,
                                                basis_set,
                                                num_channels,
                                                num_cg_levels,
                                                device=device,
                                                dtype=dtype)
        tau_pos = self.position_functions.tau

        if num_scalars_in is None:
            num_scalars_in = self.num_species * (self.charge_power + 1)
        num_scalars_out = num_channels[0]

        input = input.lower()
        if input == 'linear':
            self.input_func = InputLinear(num_scalars_in,
                                          num_scalars_out,
                                          device=self.device,
                                          dtype=self.dtype)
        elif input == 'mpnn':
            # The MPNN input layer uses the level-0 cutoff parameters.
            self.input_func = InputMPNN(num_scalars_in,
                                        num_scalars_out,
                                        num_mpnn_layers,
                                        soft_cut_rad[0],
                                        soft_cut_width[0],
                                        hard_cut_rad[0],
                                        activation=activation,
                                        device=self.device,
                                        dtype=self.dtype)
        else:
            raise ValueError(
                'Improper choice of input featurization of network! {}'.format(
                    input))

        tau_in = [num_scalars_out]

        # NOTE(review): level 0 starts from a single zero-width edge type,
        # presumably meaning "no incoming edge features" -- confirm.
        tau_edge = [0]

        logging.info('{} {}'.format(tau_in, tau_edge))

        atom_levels = nn.ModuleList()
        edge_levels = nn.ModuleList()
        for level in range(self.num_cg_levels):
            # First add the edge, since the output type determines the next level
            edge_lvl = CormorantEdgeLevel(tau_edge,
                                          tau_in,
                                          tau_pos[level],
                                          num_channels[level],
                                          cutoff_type,
                                          hard_cut_rad[level],
                                          soft_cut_rad[level],
                                          soft_cut_width[level],
                                          gaussian_mask=gaussian_mask,
                                          device=device,
                                          dtype=dtype)
            edge_levels.append(edge_lvl)
            tau_edge = edge_lvl.tau_out

            # Now add the NBody level
            atom_lvl = CormorantAtomLevel(tau_in,
                                          tau_edge,
                                          maxl[level],
                                          num_channels[level],
                                          level_gain[level],
                                          weight_init,
                                          device=device,
                                          dtype=dtype)
            atom_levels.append(atom_lvl)
            tau_in = atom_lvl.tau_out

            logging.info('{} {}'.format(tau_in, tau_edge))

        self.atom_levels = atom_levels
        self.edge_levels = edge_levels

        # The top network consumes the channels of every edge level combined.
        self.tau_levels_out = [level.tau_out for level in edge_levels]
        num_mlp_channels = sum(sum(level) for level in self.tau_levels_out)
        top = top.lower()
        if top == 'linear':
            self.top_func = OutputEdgeLinear(num_mlp_channels,
                                             num_out=num_out,
                                             bias=True,
                                             device=self.device,
                                             dtype=self.dtype)
        elif top == 'mlp':
            self.top_func = OutputEdgeMLP(num_mlp_channels,
                                          num_out=num_out,
                                          num_hidden=num_top_layers,
                                          activation=activation,
                                          device=self.device,
                                          dtype=self.dtype)
        else:
            raise ValueError(
                'Improper choice of top of network! {}'.format(top))

        logging.info('Model initialized. Number of parameters: {}'.format(
            sum(p.nelement() for p in self.parameters())))
示例#5
0
文件: model.py 项目: maschka/atom3d
    def __init__(self, maxl, max_sh, num_cg_levels, num_channels, num_species,
                 cutoff_type, hard_cut_rad, soft_cut_rad, soft_cut_width,
                 weight_init, level_gain, charge_power, basis_set,
                 charge_scale, gaussian_mask, cgprod_bounded=True,
                 cg_pow_normalization='none', cg_agg_normalization='none', 
                 device=None, dtype=None, cg_dict=None):
        """Build the network: linear input, ENN Clebsch-Gordan core, linear output.

        The per-level hyperparameters ``level_gain``, ``hard_cut_rad``,
        ``soft_cut_rad``, ``soft_cut_width``, ``maxl`` and ``max_sh`` are
        broadcast to ``num_cg_levels`` entries with ``expand_var_list``.
        ``num_channels`` is broadcast to ``num_cg_levels + 1`` entries; entry 0
        is the output width of the input layer.

        Args:
            num_species: Number of species; with ``charge_power`` it fixes the
                input scalar width ``num_species * (charge_power + 1)``.
            cgprod_bounded, cg_pow_normalization, cg_agg_normalization:
                Forwarded unchanged to ``ENN``.
            device, dtype, cg_dict: Forwarded to the base-class constructor.
                Afterwards the resolved ``self.device``, ``self.dtype`` and
                ``self.cg_dict`` are used.

        Note:
            ``gaussian_mask`` is accepted but not read by this constructor.
        """
        logging.info('Initializing network!')
        level_gain = expand_var_list(level_gain, num_cg_levels)

        hard_cut_rad = expand_var_list(hard_cut_rad, num_cg_levels)
        soft_cut_rad = expand_var_list(soft_cut_rad, num_cg_levels)
        soft_cut_width = expand_var_list(soft_cut_width, num_cg_levels)

        maxl = expand_var_list(maxl, num_cg_levels)
        max_sh = expand_var_list(max_sh, num_cg_levels)
        # One extra entry: index 0 sizes the input layer's output.
        num_channels = expand_var_list(num_channels, num_cg_levels+1)

        logging.info('hard_cut_rad: {}'.format(hard_cut_rad))
        logging.info('soft_cut_rad: {}'.format(soft_cut_rad))
        logging.info('soft_cut_width: {}'.format(soft_cut_width))
        logging.info('maxl: {}'.format(maxl))
        logging.info('max_sh: {}'.format(max_sh))
        logging.info('num_channels: {}'.format(num_channels))

        # The base class receives the largest value found in either maxl or
        # max_sh (lists are concatenated, then max is taken).
        super().__init__(maxl=max(maxl+max_sh), device=device, dtype=dtype, cg_dict=cg_dict)
        device, dtype, cg_dict = self.device, self.dtype, self.cg_dict

        self.num_cg_levels = num_cg_levels
        self.num_channels = num_channels
        self.charge_power = charge_power
        self.charge_scale = charge_scale
        self.num_species = num_species

        # Set up spherical harmonics
        self.sph_harms = SphericalHarmonicsRel(max(max_sh), conj=True,
                                               device=device, dtype=dtype, cg_dict=cg_dict)

        # Set up position functions, now independent of spherical harmonics
        self.rad_funcs = RadialFilters(max_sh, basis_set, num_channels, num_cg_levels,
                                       device=self.device, dtype=self.dtype)
        tau_pos = self.rad_funcs.tau

        # Set up input layers
        num_scalars_in = self.num_species * (self.charge_power + 1)
        num_scalars_out = num_channels[0]
        self.input_func_atom = InputLinear(num_scalars_in, num_scalars_out,
                                           device=self.device, dtype=self.dtype)
        self.input_func_edge = NoLayer()

        # Set up the central Clebsch-Gordan network
        tau_in_atom = self.input_func_atom.tau
        tau_in_edge = self.input_func_edge.tau
        self.cormorant_cg = ENN(maxl, max_sh, tau_in_atom, tau_in_edge,
                     tau_pos, num_cg_levels, num_channels, level_gain, weight_init,
                     cutoff_type, hard_cut_rad, soft_cut_rad, soft_cut_width,
                     cgprod_bounded=cgprod_bounded, 
                     cg_pow_normalization=cg_pow_normalization, 
                     cg_agg_normalization=cg_agg_normalization,
                     device=self.device, dtype=self.dtype, cg_dict=self.cg_dict)

        # Only atom-level scalars feed the output; the edge path is a NoLayer.
        self.get_scalars_atom = GetScalarsAtom(self.cormorant_cg.tau_levels_atom,
                                               device=self.device, dtype=self.dtype)
        self.get_scalars_edge = NoLayer()

        # Set up the output networks
        self.output_layer_atom = OutputLinear(self.get_scalars_atom.num_scalars, bias=True,
                                              device=self.device, dtype=self.dtype)
        self.output_layer_edge = NoLayer()

        logging.info('Model initialized. Number of parameters: {}'.format(
            sum(p.nelement() for p in self.parameters())))
示例#6
0
    def test_covariance(self, tau, num_channels, maxl, basis, edge_net_type,
                        sample_batch):
        """Check that CormorantEdgeLevel output is unchanged under rotation.

        The edge level is evaluated twice with the same edge representations:

        1. on the original atom reps, with norms and radial functions from the
           original positions;
        2. on Wigner-rotated atom reps, with norms and radial functions
           recomputed from the rotated positions.

        For each index ``i`` in ``range(maxl)`` the two outputs must agree
        elementwise to within an absolute tolerance of 1e-5.

        Args:
            tau: Channel multiplicity per irrep of the atom inputs; also the
                channel count of the random edge reps.
            num_channels: Channel count for the radial filters and edge level.
            maxl: Number of irreps; spherical harmonics go up to ``maxl - 1``.
            basis: Radial basis size, passed twice to ``RadialFilters``.
            edge_net_type: ``None`` for no input edge reps, or ``'rand'`` for
                random ``SO3Scalar`` edge reps.
            sample_batch: Fixture whose first element is the data dict;
                ``data['positions']`` fixes the device, dtype and sizes.

        Raises:
            ValueError: If ``edge_net_type`` is neither ``None`` nor ``'rand'``.
        """
        data, __, __ = sample_batch
        device, dtype = data['positions'].device, data['positions'].dtype
        sph_harms = SphericalHarmonicsRel(maxl - 1,
                                          conj=True,
                                          device=device,
                                          dtype=dtype,
                                          cg_dict=None)
        batch_size, natoms = data['positions'].shape[:2]
        D, R, __ = rot.gen_rot(maxl, device=device, dtype=dtype)

        # Setup input and its rotated counterpart.
        atom_reps, __, __, edge_mask, atom_positions = prep_input(
            data, tau, maxl)
        atom_positions_rot = rot.rotate_cart_vec(R, atom_positions)
        atom_reps_rot = atom_reps.apply_wigner(D)

        # Pairwise norms for the original and the rotated geometry.
        __, norms = sph_harms(atom_positions, atom_positions)
        __, norms_rot = sph_harms(atom_positions_rot, atom_positions_rot)

        rad_funcs = RadialFilters([maxl - 1], [basis, basis], [num_channels],
                                  1,
                                  device=device,
                                  dtype=dtype)
        # Build the radial functions separately for each geometry, so the
        # rotated branch really goes through the rotated distances.
        rad_func_levels = rad_funcs(norms, edge_mask * (norms > 0))
        rad_func_levels_rot = rad_funcs(norms_rot, edge_mask * (norms_rot > 0))
        tau_pos = rad_funcs.tau[0]

        # Build the initial edge network. The random reps must match the
        # batch's device and dtype, or the edge level mixes tensor types.
        if edge_net_type is None:
            edge_reps = None
        elif edge_net_type == 'rand':
            reps = [
                torch.randn((batch_size, natoms, natoms, tau, 2),
                            device=device,
                            dtype=dtype)
                for _ in range(maxl)
            ]
            edge_reps = SO3Scalar(reps)
        else:
            raise ValueError(
                'Unknown edge_net_type: {!r}'.format(edge_net_type))

        # Build Edge layer
        tlist = [tau] * maxl
        tau_atoms = tlist
        tau_edge = tlist if edge_net_type is not None else []

        edge_lvl = CormorantEdgeLevel(tau_atoms,
                                      tau_edge,
                                      tau_pos,
                                      num_channels,
                                      maxl,
                                      cutoff_type='soft',
                                      device=device,
                                      dtype=dtype,
                                      hard_cut_rad=1.73,
                                      soft_cut_rad=1.73,
                                      soft_cut_width=0.2)

        output_edge_reps = edge_lvl(edge_reps, atom_reps, rad_func_levels[0],
                                    edge_mask, norms)
        output_edge_reps_rot = edge_lvl(edge_reps, atom_reps_rot,
                                        rad_func_levels_rot[0], edge_mask,
                                        norms_rot)

        for i in range(maxl):
            assert (torch.max(
                torch.abs(output_edge_reps[i] - output_edge_reps_rot[i])) <
                    1E-5)