Пример #1
0
    def __init__(
        self,
        maxl,
        max_sh,
        num_cg_levels,
        num_channels,
        num_species,
        cutoff_type,
        hard_cut_rad,
        soft_cut_rad,
        soft_cut_width,
        weight_init,
        level_gain,
        charge_power,
        basis_set,
        charge_scale,
        bag_scale,
        device=None,
        dtype=None,
        cg_dict=None,
    ) -> None:
        """Assemble the spherical harmonics, radial filters, input layer and CG stack.

        Args:
            maxl: Per-level setting; expanded to ``num_cg_levels`` entries and
                forwarded to ``CormorantCG``.
            max_sh: Per-level setting; expanded to ``num_cg_levels`` entries.
                Its maximum sizes ``SphericalHarmonicsRel`` and the full list
                goes to ``RadialFilters`` and ``CormorantCG``.
            num_cg_levels: Number of CG levels.
            num_channels: Expanded to ``num_cg_levels + 1`` entries; entry 0 is
                the output width of the atom input layer.
            num_species: Used with ``charge_power`` to size the input scalars.
            cutoff_type: Forwarded to ``CormorantCG``.
            hard_cut_rad: Per-level setting forwarded to ``CormorantCG``.
            soft_cut_rad: Per-level setting forwarded to ``CormorantCG``.
            soft_cut_width: Per-level setting forwarded to ``CormorantCG``.
            weight_init: Forwarded to ``CormorantCG``.
            level_gain: Per-level setting forwarded to ``CormorantCG``.
            charge_power: Stored on the instance; enters the input-scalar count.
            basis_set: Forwarded to ``RadialFilters``.
            charge_scale: Stored on the instance.
            bag_scale: Stored on the instance.
            device: Forwarded to the parent constructor.
            dtype: Forwarded to the parent constructor.
            cg_dict: Forwarded to the parent constructor.
        """
        # Expand every per-level setting, in the same order as the arguments
        # are listed here. (expand_var_list presumably pads/broadcasts each
        # value to one entry per level -- confirm against its definition.)
        (level_gain, hard_cut_rad, soft_cut_rad, soft_cut_width, maxl,
         max_sh) = [
             expand_var_list(setting, num_cg_levels)
             for setting in (level_gain, hard_cut_rad, soft_cut_rad,
                             soft_cut_width, maxl, max_sh)
         ]
        # Channels get one extra entry compared with the other settings.
        num_channels = expand_var_list(num_channels, num_cg_levels + 1)

        # The parent is sized by the largest value found in either list.
        largest_l = max(maxl + max_sh)
        super().__init__(maxl=largest_l,
                         device=device,
                         dtype=dtype,
                         cg_dict=cg_dict)

        self.num_cg_levels = num_cg_levels
        self.num_channels = num_channels
        self.charge_power = charge_power
        self.charge_scale = charge_scale
        self.bag_scale = bag_scale
        self.num_species = num_species

        # From here on use the values held by the instance after the parent
        # constructor ran, not the raw arguments.
        placement = {'device': self.device, 'dtype': self.dtype}

        # Spherical harmonics.
        self.sph_harms = SphericalHarmonicsRel(maxl=max(max_sh),
                                               conj=True,
                                               cg_dict=self.cg_dict,
                                               **placement)

        # Radial (position) functions, independent of the spherical harmonics.
        self.rad_funcs = RadialFilters(max_sh=max_sh,
                                       basis_set=basis_set,
                                       num_channels_out=num_channels,
                                       num_levels=num_cg_levels,
                                       **placement)

        # Atom input layer: num_species * (charge_power + 1) scalars plus a
        # further num_species scalars, mapped to the first channel count.
        scalars_in = num_species * (charge_power + 1) + num_species
        scalars_out = num_channels[0]
        self.input_func_atom = InputLinear(scalars_in, scalars_out,
                                           **placement)
        self.input_func_edge = NoLayer()

        # Central CG network, wired to the taus exposed by the layers above.
        self.cormorant_cg = CormorantCG(
            maxl=maxl,
            max_sh=max_sh,
            tau_in_atom=self.input_func_atom.tau,
            tau_in_edge=self.input_func_edge.tau,
            tau_pos=self.rad_funcs.tau,
            num_cg_levels=num_cg_levels,
            num_channels=num_channels,
            level_gain=level_gain,
            weight_init=weight_init,
            cutoff_type=cutoff_type,
            hard_cut_rad=hard_cut_rad,
            soft_cut_rad=soft_cut_rad,
            soft_cut_width=soft_cut_width,
            cat=True,
            gaussian_mask=False,
            cg_dict=self.cg_dict,
            **placement)
Пример #2
0
    def __init__(self,
                 maxl,
                 max_sh,
                 num_cg_levels,
                 num_channels,
                 num_species,
                 cutoff_type,
                 hard_cut_rad,
                 soft_cut_rad,
                 soft_cut_width,
                 weight_init,
                 level_gain,
                 charge_power,
                 basis_set,
                 charge_scale,
                 gaussian_mask,
                 num_mpnn_layers=64,
                 activation='leakyrelu',
                 num_classes=2,
                 device=None,
                 dtype=None,
                 cg_dict=None):
        """Build the network: harmonics, radial filters, input, ENN core, output.

        Args:
            maxl: Per-level setting; expanded to ``num_cg_levels`` entries and
                passed to ``ENN``.
            max_sh: Per-level setting; expanded to ``num_cg_levels`` entries.
                Its maximum sizes ``SphericalHarmonicsRel``.
            num_cg_levels: Number of CG levels.
            num_channels: Expanded to ``num_cg_levels + 1`` entries; entry 0 is
                the output width of the atom input layer.
            num_species: With ``charge_power``, sizes the input scalars as
                ``num_species * (charge_power + 1)``.
            cutoff_type: Passed to ``ENN``.
            hard_cut_rad: Per-level setting passed to ``ENN``.
            soft_cut_rad: Per-level setting passed to ``ENN``.
            soft_cut_width: Per-level setting passed to ``ENN``.
            weight_init: Passed to ``ENN``.
            level_gain: Per-level setting passed to ``ENN``.
            charge_power: Stored on the instance; enters the input-scalar count.
            basis_set: Passed to ``RadialFilters``.
            charge_scale: Stored on the instance.
            gaussian_mask: Accepted but not read anywhere in this constructor.
            num_mpnn_layers: Passed to ``OutputSoftmaxPMLP`` as ``num_mixed``.
            activation: Passed to ``OutputSoftmaxPMLP``.
            num_classes: Passed to ``OutputSoftmaxPMLP``.
            device: Forwarded to the parent constructor.
            dtype: Forwarded to the parent constructor.
            cg_dict: Forwarded to the parent constructor.
        """
        logging.info('Initializing network!')

        # Expand the per-level settings in a fixed order. (expand_var_list
        # presumably yields one entry per level -- confirm in its definition.)
        (level_gain, hard_cut_rad, soft_cut_rad, soft_cut_width, maxl,
         max_sh) = [
             expand_var_list(setting, num_cg_levels)
             for setting in (level_gain, hard_cut_rad, soft_cut_rad,
                             soft_cut_width, maxl, max_sh)
         ]
        num_channels = expand_var_list(num_channels, num_cg_levels + 1)

        # Report the expanded settings.
        report = (
            ('hard_cut_rad: {}', hard_cut_rad),
            ('soft_cut_rad: {}', soft_cut_rad),
            ('soft_cut_width: {}', soft_cut_width),
            ('maxl: {}', maxl),
            ('max_sh: {}', max_sh),
            ('num_channels: {}', num_channels),
        )
        for template, value in report:
            logging.info(template.format(value))

        super().__init__(maxl=max(maxl + max_sh),
                         device=device,
                         dtype=dtype,
                         cg_dict=cg_dict)
        # Continue with the values the parent constructor left on the instance.
        device = self.device
        dtype = self.dtype
        cg_dict = self.cg_dict

        self.num_cg_levels = num_cg_levels
        self.num_channels = num_channels
        self.charge_power = charge_power
        self.charge_scale = charge_scale
        self.num_species = num_species

        # Spherical harmonics.
        self.sph_harms = SphericalHarmonicsRel(
            max(max_sh), conj=True, device=device, dtype=dtype,
            cg_dict=cg_dict)

        # Radial (position) functions, independent of the spherical harmonics.
        self.rad_funcs = RadialFilters(
            max_sh, basis_set, num_channels, num_cg_levels,
            device=device, dtype=dtype)
        radial_tau = self.rad_funcs.tau

        # Input layers.
        self.input_func_atom = InputLinear(
            num_species * (charge_power + 1), num_channels[0],
            device=device, dtype=dtype)
        self.input_func_edge = NoLayer()

        # Central Clebsch-Gordan network.
        # NOTE(review): the `gaussian_mask` argument is not forwarded; ENN is
        # always given gaussian_mask=False. Confirm this is intended.
        atom_tau = self.input_func_atom.tau
        edge_tau = self.input_func_edge.tau
        self.cormorant_cg = ENN(
            maxl, max_sh, atom_tau, edge_tau, radial_tau,
            num_cg_levels, num_channels, level_gain, weight_init,
            cutoff_type, hard_cut_rad, soft_cut_rad, soft_cut_width,
            cat=True,
            gaussian_mask=False,
            cgprod_bounded=True,
            cg_agg_normalization='none',
            cg_pow_normalization='none',
            device=device,
            dtype=dtype,
            cg_dict=cg_dict)

        # Atom and edge scalars.
        atom_level_taus = self.cormorant_cg.tau_levels_atom
        # Not consumed below; the attribute lookup is retained as-is.
        edge_level_taus = self.cormorant_cg.tau_levels_edge
        self.get_scalars_atom = GetScalarsAtom(
            atom_level_taus, device=device, dtype=dtype)
        self.get_scalars_edge = NoLayer()

        # Output networks.
        atom_scalar_count = self.get_scalars_atom.num_scalars
        # Not consumed below; the attribute lookup is retained as-is.
        edge_scalar_count = self.get_scalars_edge.num_scalars
        self.output_layer_atom = OutputSoftmaxPMLP(
            atom_scalar_count, num_classes,
            num_mixed=num_mpnn_layers,
            activation=activation,
            device=device,
            dtype=dtype)
        self.output_layer_edge = NoLayer()

        param_total = sum(p.nelement() for p in self.parameters())
        logging.info(
            'Model initialized. Number of parameters: {}'.format(param_total))
Пример #3
0
    def __init__(self,
                 num_cg_levels,
                 maxl,
                 max_sh,
                 num_channels,
                 num_species,
                 cutoff_type,
                 hard_cut_rad,
                 soft_cut_rad,
                 soft_cut_width,
                 weight_init,
                 level_gain,
                 charge_power,
                 basis_set,
                 charge_scale,
                 gaussian_mask,
                 top,
                 input,
                 num_mpnn_layers,
                 num_top_layers,
                 num_scalars_in=None,
                 num_out=1,
                 activation='leakyrelu',
                 additional_atom_features=None,
                 device=torch.device('cpu'),
                 dtype=torch.float):
        """Build the input layer, the stacked edge/atom levels and the output head.

        Args:
            num_cg_levels: Number of (edge level, atom level) pairs to stack.
            maxl: Stored unmodified on ``self.maxl``; the expanded per-level
                list supplies ``maxl[level]`` to each ``CormorantAtomLevel``.
            max_sh: Stored unmodified on ``self.max_sh``; the expanded
                per-level list sizes ``SphericalHarmonicsRel`` (via its
                maximum) and is passed to ``RadialFilters``.
            num_channels: Expanded to ``num_cg_levels`` entries; entry 0 is the
                output width of the input layer, entry ``level`` is passed to
                both level constructors.
            num_species: With ``charge_power``, gives the default input width.
            cutoff_type: Passed to every ``CormorantEdgeLevel``.
            hard_cut_rad: Per-level value for ``CormorantEdgeLevel``; entry 0
                is also passed to ``InputMPNN``.
            soft_cut_rad: Per-level value for ``CormorantEdgeLevel``; entry 0
                is also passed to ``InputMPNN``.
            soft_cut_width: Per-level value for ``CormorantEdgeLevel``; entry 0
                is also passed to ``InputMPNN``.
            weight_init: Passed to every ``CormorantAtomLevel``.
            level_gain: Per-level value for ``CormorantAtomLevel``.
            charge_power: Stored on the instance; with ``num_species``, gives
                the default input width.
            basis_set: Passed to ``RadialFilters``.
            charge_scale: Stored on the instance.
            gaussian_mask: Passed to every ``CormorantEdgeLevel``.
            top: Output head, ``'linear'`` or ``'mlp'`` (case-insensitive).
            input: Input featurization, ``'linear'`` or ``'mpnn'``
                (case-insensitive).
            num_mpnn_layers: Passed to ``InputMPNN`` (only when
                ``input == 'mpnn'``).
            num_top_layers: Passed to ``OutputEdgeMLP`` as ``num_hidden`` (only
                when ``top == 'mlp'``).
            num_scalars_in: Width of the input scalars. When ``None`` it
                defaults to ``num_species * (charge_power + 1)``.
            num_out: Passed to the output head.
            activation: Passed to ``InputMPNN`` and ``OutputEdgeMLP``.
            additional_atom_features: Optional iterable; a list copy is stored
                on the instance (an empty list when ``None``).
            device: Stored on the instance and passed to every sub-module.
            dtype: Stored on the instance and passed to every sub-module.

        Raises:
            ValueError: If ``input`` or ``top`` is not one of the supported
                choices.
        """
        super(EdgeCormorant, self).__init__()

        logging.info('Initializing network!')

        self.num_cg_levels = num_cg_levels
        # The instance keeps the arguments exactly as they were passed in,
        # i.e. before the per-level expansion below.
        self.maxl = maxl
        self.max_sh = max_sh

        self.device = device
        self.dtype = dtype

        # Expand the per-level settings. (expand_var_list presumably yields one
        # entry per level -- confirm against its definition.)
        level_gain = expand_var_list(level_gain, num_cg_levels)

        hard_cut_rad = expand_var_list(hard_cut_rad, num_cg_levels)
        soft_cut_rad = expand_var_list(soft_cut_rad, num_cg_levels)
        soft_cut_width = expand_var_list(soft_cut_width, num_cg_levels)

        maxl = expand_var_list(maxl, num_cg_levels)
        max_sh = expand_var_list(max_sh, num_cg_levels)
        num_channels = expand_var_list(num_channels, num_cg_levels)

        self.num_channels = num_channels
        self.charge_power = charge_power
        self.charge_scale = charge_scale
        self.num_species = num_species

        if additional_atom_features is None:
            self.additional_atom_features = []
        else:
            self.additional_atom_features = list(additional_atom_features)

        logging.info('hard_cut_rad: {}'.format(hard_cut_rad))
        logging.info('soft_cut_rad: {}'.format(soft_cut_rad))
        logging.info('soft_cut_width: {}'.format(soft_cut_width))
        logging.info('maxl: {}'.format(maxl))
        logging.info('max_sh: {}'.format(max_sh))
        logging.info('num_channels: {}'.format(num_channels))

        # Set up spherical harmonics.
        # Take the maximum over the expanded per-level list, the same list
        # RadialFilters receives below. The raw `self.max_sh` may not be
        # iterable (e.g. a single number), in which case max() would raise
        # TypeError.
        self.spherical_harmonics_rel = SphericalHarmonicsRel(max(max_sh),
                                                             sh_norm='unit',
                                                             device=device,
                                                             dtype=dtype)

        # Set up position functions, now independent of spherical harmonics
        self.position_functions = RadialFilters(max_sh,
                                                basis_set,
                                                num_channels,
                                                num_cg_levels,
                                                device=device,
                                                dtype=dtype)
        tau_pos = self.position_functions.tau

        # Fall back to the default input width only when none was given.
        if num_scalars_in is None:
            num_scalars_in = self.num_species * (self.charge_power + 1)
        num_scalars_out = num_channels[0]

        input = input.lower()
        if input == 'linear':
            self.input_func = InputLinear(num_scalars_in,
                                          num_scalars_out,
                                          device=self.device,
                                          dtype=self.dtype)
        elif input == 'mpnn':
            # The MPNN input layer uses the first level's cutoff settings.
            self.input_func = InputMPNN(num_scalars_in,
                                        num_scalars_out,
                                        num_mpnn_layers,
                                        soft_cut_rad[0],
                                        soft_cut_width[0],
                                        hard_cut_rad[0],
                                        activation=activation,
                                        device=self.device,
                                        dtype=self.dtype)
        else:
            raise ValueError(
                'Improper choice of input featurization of network! {}'.format(
                    input))

        # Starting types fed to the first level: the input layer's scalar
        # channels for atoms and [0] for edges.
        tau_in = [num_scalars_out]

        tau_edge = [0]

        logging.info('{} {}'.format(tau_in, tau_edge))

        atom_levels = nn.ModuleList()
        edge_levels = nn.ModuleList()
        for level in range(self.num_cg_levels):
            # First add the edge, since the output type determines the next level
            edge_lvl = CormorantEdgeLevel(tau_edge,
                                          tau_in,
                                          tau_pos[level],
                                          num_channels[level],
                                          cutoff_type,
                                          hard_cut_rad[level],
                                          soft_cut_rad[level],
                                          soft_cut_width[level],
                                          gaussian_mask=gaussian_mask,
                                          device=device,
                                          dtype=dtype)
            edge_levels.append(edge_lvl)
            tau_edge = edge_lvl.tau_out

            # Now add the NBody level; it consumes the edge type just produced
            # and its own output type feeds the next iteration.
            atom_lvl = CormorantAtomLevel(tau_in,
                                          tau_edge,
                                          maxl[level],
                                          num_channels[level],
                                          level_gain[level],
                                          weight_init,
                                          device=device,
                                          dtype=dtype)
            atom_levels.append(atom_lvl)
            tau_in = atom_lvl.tau_out

            logging.info('{} {}'.format(tau_in, tau_edge))

        self.atom_levels = atom_levels
        self.edge_levels = edge_levels

        # The output head is sized by the summed entries of every edge level's
        # output type.
        self.tau_levels_out = [level.tau_out for level in edge_levels]
        num_mlp_channels = sum(sum(tau) for tau in self.tau_levels_out)
        top = top.lower()
        if top == 'linear':
            self.top_func = OutputEdgeLinear(num_mlp_channels,
                                             num_out=num_out,
                                             bias=True,
                                             device=self.device,
                                             dtype=self.dtype)
        elif top == 'mlp':
            self.top_func = OutputEdgeMLP(num_mlp_channels,
                                          num_out=num_out,
                                          num_hidden=num_top_layers,
                                          activation=activation,
                                          device=self.device,
                                          dtype=self.dtype)
        else:
            raise ValueError(
                'Improper choice of top of network! {}'.format(top))

        logging.info('Model initialized. Number of parameters: {}'.format(
            sum(p.nelement() for p in self.parameters())))