Ejemplo n.º 1
0
    def __init__(self, architecture, device):
        """Assemble the sub-modules of the coupled multi-task density model.

        Args:
            architecture: mapping that must provide the keys ``common_net``,
                ``atom_net``, ``mol_net``, ``bond_net``, ``fe_layer``,
                ``fe_net``, ``iso_order``, ``aniso_order`` and
                ``coefficients_distribution``. The last entry is further
                indexed with ``'isotropic'`` and ``'anisotropic'``.
            device: device handle stored on the instance and forwarded to
                every sub-module that takes a ``device`` argument.

        Raises:
            IOError: if any of the required keys is missing from
                ``architecture``.
        """
        from a2mdnet.a2mdt.modules import A2MDt
        super(MultTaskDensityCoupled, self).__init__()

        self.device = device

        # All keys are looked up in this order, including 'iso_order' and
        # 'aniso_order' (not used below), so a missing one is still reported.
        required_keys = (
            'common_net',
            'atom_net',
            'mol_net',
            'bond_net',
            'fe_layer',
            'fe_net',
            'iso_order',
            'aniso_order',
            'coefficients_distribution',
        )
        try:
            spec = {key: architecture[key] for key in required_keys}
        except KeyError:
            raise IOError("architecture dict did not contain the right fields")

        coefficients = spec['coefficients_distribution']

        # Element-specific networks; each call receives its own elements list.
        self.common_net = modules.TorchElementSpecificA2MDNN(
            nodes=spec['common_net'],
            elements=[1, 6, 7, 8],
            device=self.device,
        )
        self.mol_net = modules.TorchElementSpecificA2MDNN(
            nodes=spec['mol_net'],
            elements=[1, 6, 7, 8],
            device=self.device,
        )
        self.atom_net = modules.TorchElementSpecificA2MDNN(
            nodes=spec['atom_net'],
            elements=[1, 6, 7, 8],
            device=self.device,
            distribution=coefficients['isotropic'],
        )

        # Pair-specific network for the bond terms.
        self.bond_net = modules.TorchPairSpecificA2MDNN(
            nodes=spec['bond_net'],
            elements=[1, 6, 7, 8],
            device=self.device,
            distribution=coefficients['anisotropic'],
        )

        # A feature-extraction layer selects the TorchaniFeats extractor;
        # without one, SymFeats is used instead.
        if spec['fe_layer'] is not None:
            self.feature_extractor = modules.TorchaniFeats(
                net=spec['fe_net'],
                feats_layer=spec['fe_layer'],
                device=self.device,
            )
        else:
            self.feature_extractor = modules.SymFeats()

        # Exclude the feature extractor's parameters from gradient updates.
        for frozen_param in self.feature_extractor.parameters():
            frozen_param.requires_grad = False

        self.normalizer = modules.QPChargeNormalization(device=self.device)
        self.density_model = A2MDt(params=A2MD_MODEL, device=self.device)
Ejemplo n.º 2
0
    def __init__(self, architecture):
        """Build the sub-modules of the normalized A2MD density model.

        Args:
            architecture: mapping that must provide the keys ``common_net``,
                ``atom_net``, ``bond_net`` and ``subnet``.

        Raises:
            IOError: if any of the required keys is missing from
                ``architecture``; the original ``KeyError`` is attached as
                the cause.
        """
        super(NormA2MDDensity, self).__init__()

        try:
            common_net_architecture = architecture['common_net']
            atom_net_architecture = architecture['atom_net']
            bond_net_architecture = architecture['bond_net']
            subnet = architecture['subnet']
        except KeyError as err:
            raise IOError(
                "architecture dict did not contain the right fields") from err

        self.common_net = modules.TorchElementSpecificA2MDNN(
            nodes=common_net_architecture,
            elements=[1, 6, 7, 8],
            device=DEVICE)
        self.atom_net = modules.TorchElementSpecificA2MDNN(
            nodes=atom_net_architecture, elements=[1, 6, 7, 8], device=DEVICE)
        # Pass the device explicitly so the bond network is created on the
        # same device as the other sub-networks instead of relying on the
        # constructor's default.
        self.bond_net = modules.TorchPairSpecificA2MDNN(
            nodes=bond_net_architecture, elements=[1, 6, 7, 8], device=DEVICE)
        self.feature_extractor = modules.TorchaniFeats(net=subnet,
                                                       device=DEVICE)

        # Exclude the feature extractor's parameters from gradient updates.
        for param in self.feature_extractor.parameters():
            param.requires_grad = False

        # Same reasoning as for bond_net: keep the normalizer on DEVICE.
        self.normalizer = modules.QPChargeNormalization(device=DEVICE)
Ejemplo n.º 3
0
    def __init__(self, architecture):
        """Build the sub-modules of the baseline model.

        Args:
            architecture: mapping that must provide the keys ``common_net``,
                ``atom_net`` and ``bond_net``.

        Raises:
            IOError: if any of the required keys is missing from
                ``architecture``.
        """
        super(Baseline, self).__init__()

        try:
            common_net_architecture, atom_net_architecture, \
                bond_net_architecture = [
                    architecture[key]
                    for key in ('common_net', 'atom_net', 'bond_net')
                ]
        except KeyError:
            raise IOError("architecture dict did not contain the right fields")

        # Element-specific networks, all created on the module-level DEVICE.
        self.common_net = modules.TorchElementSpecificA2MDNN(
            nodes=common_net_architecture,
            elements=[1, 6, 7, 8],
            device=DEVICE,
        )
        self.atom_net = modules.TorchElementSpecificA2MDNN(
            nodes=atom_net_architecture,
            elements=[1, 6, 7, 8],
            device=DEVICE,
        )

        # The bond term uses the pair-distance network in this model.
        self.bond_net = modules.TorchPairDistances(
            nodes=bond_net_architecture,
            elements=[1, 6, 7, 8],
            device=DEVICE,
        )

        self.feature_extractor = modules.SymFeats()

        # Exclude the feature extractor's parameters from gradient updates.
        for frozen_param in self.feature_extractor.parameters():
            frozen_param.requires_grad = False

        self.normalizer = modules.QPChargeNormalization(device=DEVICE)
        self.pair_distances = APEV(
            output_size=10,
            parameters=PAIR_FEATURES,
            device=DEVICE,
        )