示例#1
0
def test_smear_gaussian_trainable():
    """Trainable Gaussian smearing matches the manual expansion and
    exposes its offsets/widths as two trainable parameter tensors."""
    distances = torch.tensor([[[0.0, 1.0, 1.5, 0.25], [0.5, 1.5, 3.0, 1.0]]])
    # expand with 5 trainable Gaussians spaced 0.75 apart on [1, 4]
    layer = GaussianSmearing(start=1., stop=4., n_gaussians=5, trainable=True)
    # |d - mu| for every distance/center pair
    centered = torch.tensor(
        [[[[1, 1.75, 2.5, 3.25, 4.],
           [0, 0.75, 1.5, 2.25, 3.],
           [0.5, 0.25, 1., 1.75, 2.5],
           [0.75, 1.5, 2.25, 3., 3.75]],
          [[0.5, 1.25, 2., 2.75, 3.5],
           [0.5, 0.25, 1., 1.75, 2.5],
           [2., 1.25, 0.5, 0.25, 1.],
           [0, 0.75, 1.5, 2.25, 3.]]]])
    reference = torch.exp((-0.5 / 0.75**2) * centered**2)
    assert torch.allclose(reference, layer(distances), atol=0.0, rtol=1.0e-7)
    trainable_params = list(layer.parameters())
    assert len(trainable_params) == 2
    assert all(len(p) == 5 for p in trainable_params)
    # centered mode: Gaussians sit at the origin, widths come from the offsets
    layer = GaussianSmearing(start=1., stop=4., n_gaussians=5, trainable=True,
                             centered=True)
    coeff = -0.5 / torch.tensor([1, 1.75, 2.5, 3.25, 4])**2
    reference = torch.exp(coeff * distances[:, :, :, None]**2)
    assert torch.allclose(reference, layer(distances), atol=0.0, rtol=1.0e-7)
    trainable_params = list(layer.parameters())
    assert len(trainable_params) == 2
    assert all(len(p) == 5 for p in trainable_params)
示例#2
0
def test_smear_gaussian_one_distance():
    """Smearing of a single distance value, with and without trainable widths."""
    distances = torch.tensor([[[1.0]]])
    # fixed (non-trainable) Gaussians expose no parameters
    layer = GaussianSmearing(n_gaussians=6, centered=False, trainable=False)
    reference = torch.exp(-0.5 * torch.tensor([[[1., 0., 1., 4., 9., 16.]]]))
    assert torch.allclose(reference, layer(distances), atol=0.0, rtol=1.0e-7)
    assert list(layer.parameters()) == []
    # trainable Gaussians start from the same initial values
    layer = GaussianSmearing(n_gaussians=6, centered=False, trainable=True)
    assert torch.allclose(reference, layer(distances), atol=0.0, rtol=1.0e-7)
    trainable_params = list(layer.parameters())
    assert len(trainable_params) == 2
    assert all(len(p) == 6 for p in trainable_params)
    # centered mode: widths taken from the offset grid
    layer = GaussianSmearing(n_gaussians=6, centered=True)
    coeff = -0.5 / torch.tensor([0., 1, 2, 3, 4, 5])**2
    reference = torch.exp(coeff * distances[:, :, :, None]**2)
    assert torch.allclose(reference, layer(distances), atol=0.0, rtol=1.0e-7)
    assert list(layer.parameters()) == []
示例#3
0
def test_smear_gaussian():
    """Gaussian smearing over a grid of centers matches the manual expansion."""
    distances = torch.tensor([[[0.0, 1.0, 1.5], [0.5, 1.5, 3.0]]])
    # 4 Gaussians with unit spacing on [1, 4]
    layer = GaussianSmearing(start=1., stop=4., n_gaussians=4)
    # |d - mu| for every distance/center pair
    centered = torch.tensor(
        [[[[1, 2, 3, 4], [0, 1, 2, 3], [0.5, 0.5, 1.5, 2.5]],
          [[.5, 1.5, 2.5, 3.5], [.5, .5, 1.5, 2.5], [2, 1, 0, 1]]]])
    reference = torch.exp(-0.5 * centered**2)
    assert torch.allclose(reference, layer(distances), atol=0.0, rtol=1.0e-7)
    assert list(layer.parameters()) == []
    # centered mode: widths taken from the offset grid
    layer = GaussianSmearing(start=1., stop=4., n_gaussians=4, centered=True)
    coeff = -0.5 / torch.tensor([1, 2, 3, 4.])**2
    reference = torch.exp(coeff * distances[:, :, :, None]**2)
    assert torch.allclose(reference, layer(distances), atol=0.0, rtol=1.0e-7)
    assert list(layer.parameters()) == []
示例#4
0
def test_gaussian_smearing(n_spatial_basis, distances):
    """The expansion appends a basis axis of size n_spatial_basis to the input."""
    model = GaussianSmearing(n_gaussians=n_spatial_basis)
    expected_shape = list(distances.shape) + [n_spatial_basis]
    assert_equal_shape(model, [distances], expected_shape)
示例#5
0
    def __init__(self,
                 n_atom_basis=128,
                 n_filters=128,
                 n_interactions=3,
                 cutoff=5.0,
                 n_gaussians=25,
                 normalize_filter=False,
                 coupled_interactions=False,
                 return_intermediate=False,
                 max_z=100,
                 cutoff_network=HardCutoff,
                 trainable_gaussians=False,
                 distance_expansion=None,
                 charged_systems=False,
                 use_noise=False,
                 noise_mean=0,
                 noise_std=1,
                 chargeEmbedding=True,
                 ownFeatures=False,
                 nFeatures=8,
                 finalFeature=None,
                 finalFeatureStart=7,
                 finalFeatureStop=8):
        """Build the SchNet representation network.

        Args:
            n_atom_basis (int): dimension of the per-atom feature vectors.
            n_filters (int or list of int): filters per interaction block; a
                list supplies an individual count for each block.
            n_interactions (int): number of interaction blocks.
            cutoff (float): cutoff radius for the default distance expansion
                and for the cutoff network of each interaction block.
            n_gaussians (int): number of Gaussians in the distance expansion.
            normalize_filter (bool): forwarded to SchNetInteraction.
            coupled_interactions (bool): if True, the same interaction block
                instance (and hence the same weights) is reused for every step.
            return_intermediate (bool): stored flag; presumably makes forward
                return per-block intermediates — confirm in forward.
            max_z (int): number of embedding rows, i.e. highest atomic number
                supported (index 0 is the padding entry).
            cutoff_network: class of the cutoff function (default HardCutoff).
            trainable_gaussians (bool): whether the default GaussianSmearing
                has trainable parameters.
            distance_expansion: optional custom expansion module; when None a
                GaussianSmearing over [0, cutoff] is created.
            charged_systems (bool): if True, add a trainable charge vector.
            use_noise (bool): stored flag; noise_mean / noise_std describe the
                noise distribution — usage not visible here, see forward.
            chargeEmbedding (bool): include the atomic-number embedding table.
            ownFeatures (bool): include a Dense projection of user-supplied
                per-atom features of width nFeatures.
            nFeatures (int): width of the user-supplied feature vector; must
                not be None when ownFeatures is True.
            finalFeature / finalFeatureStart / finalFeatureStop: stored for
                later use — semantics not visible here, see forward.

        Raises:
            NotImplementedError: if neither chargeEmbedding nor ownFeatures
                is enabled, or if ownFeatures is enabled with nFeatures=None.
        """
        super(SchNet, self).__init__()

        # plain configuration attributes, consumed later (e.g. in forward)
        self.finalFeature = finalFeature
        self.finalFeatureStart = finalFeatureStart
        self.finalFeatureStop = finalFeatureStop
        self.chargeEmbedding = chargeEmbedding
        self.ownFeatures = ownFeatures
        self.n_atom_basis = n_atom_basis

        # make a lookup table to store embeddings for each element (up to atomic
        # number max_z) each of which is a vector of size n_atom_basis
        if chargeEmbedding and not ownFeatures:
            self.embedding = nn.Embedding(max_z, n_atom_basis, padding_idx=0)
        elif chargeEmbedding and ownFeatures:
            # split the feature budget: half from the element embedding, half
            # from a Dense projection of the user-supplied features
            if nFeatures is None:
                raise NotImplementedError
            self.embedding = nn.Embedding(max_z,
                                          int(n_atom_basis / 2),
                                          padding_idx=0)
            self.denseEmbedding = Dense(nFeatures, int(n_atom_basis / 2))
        elif ownFeatures and not chargeEmbedding:
            # user-supplied features only: project them to the full atom basis
            if nFeatures is None:
                raise NotImplementedError
            self.denseEmbedding = Dense(nFeatures, n_atom_basis)
        else:
            # at least one of the two feature sources must be enabled
            raise NotImplementedError

        # layer for computing interatomic distances
        self.distances = AtomDistances()

        # layer for expanding interatomic distances in a basis
        if distance_expansion is None:
            self.distance_expansion = GaussianSmearing(
                0.0, cutoff, n_gaussians, trainable=trainable_gaussians)
        else:
            self.distance_expansion = distance_expansion

        # block for computing interaction
        if isinstance(n_filters, list):
            # per-block filter counts: one independent block per list entry
            self.interactions = nn.ModuleList([
                SchNetInteraction(
                    n_atom_basis=n_atom_basis,
                    n_spatial_basis=n_gaussians,
                    n_filters=n_filters[i],
                    cutoff_network=cutoff_network,
                    cutoff=cutoff,
                    normalize_filter=normalize_filter,
                ) for i in range(n_interactions)
            ])

        elif coupled_interactions:
            # use the same SchNetInteraction instance (hence the same weights)
            self.interactions = nn.ModuleList([
                SchNetInteraction(
                    n_atom_basis=n_atom_basis,
                    n_spatial_basis=n_gaussians,
                    n_filters=n_filters,
                    cutoff_network=cutoff_network,
                    cutoff=cutoff,
                    normalize_filter=normalize_filter,
                )
            ] * n_interactions)
        else:
            # use one SchNetInteraction instance for each interaction
            self.interactions = nn.ModuleList([
                SchNetInteraction(
                    n_atom_basis=n_atom_basis,
                    n_spatial_basis=n_gaussians,
                    n_filters=n_filters,
                    cutoff_network=cutoff_network,
                    cutoff=cutoff,
                    normalize_filter=normalize_filter,
                ) for _ in range(n_interactions)
            ])

        # set attributes
        self.use_noise = use_noise
        self.noise_mean = noise_mean
        self.noise_std = noise_std
        self.return_intermediate = return_intermediate
        self.charged_systems = charged_systems
        if charged_systems:
            # trainable per-feature charge vector, initialised N(0, 1/sqrt(n))
            self.charge = nn.Parameter(torch.Tensor(1, n_atom_basis))
            self.charge.data.normal_(0, 1.0 / n_atom_basis**0.5)
示例#6
0
    def __init__(
        self,
        n_atom_basis=128,
        n_filters=128,
        n_interactions=3,
        cutoff=5.0,
        n_gaussians=25,
        normalize_filter=False,
        coupled_interactions=False,
        return_intermediate=False,
        max_z=100,
        cutoff_network=CosineCutoff,
        trainable_gaussians=False,
        distance_expansion=None,
        charged_systems=False,
    ):
        """Assemble the SchNet representation: element embedding, distance
        layer, distance expansion, and the stack of interaction blocks."""
        super(SchNet, self).__init__()

        # plain configuration attributes
        self.n_atom_basis = n_atom_basis
        self.return_intermediate = return_intermediate
        self.charged_systems = charged_systems

        # lookup table mapping each atomic number (up to max_z) to a learned
        # vector of size n_atom_basis; index 0 is the padding entry
        self.embedding = nn.Embedding(max_z, n_atom_basis, padding_idx=0)

        # layer computing interatomic distances
        self.distances = AtomDistances()

        # expand distances in a Gaussian basis unless a custom module is given
        if distance_expansion is not None:
            self.distance_expansion = distance_expansion
        else:
            self.distance_expansion = GaussianSmearing(
                0.0, cutoff, n_gaussians, trainable=trainable_gaussians)

        def build_interaction():
            # one interaction block with the configured hyperparameters
            return SchNetInteraction(
                n_atom_basis=n_atom_basis,
                n_spatial_basis=n_gaussians,
                n_filters=n_filters,
                cutoff_network=cutoff_network,
                cutoff=cutoff,
                normalize_filter=normalize_filter,
            )

        if coupled_interactions:
            # repeat the same module instance so all blocks share weights
            blocks = [build_interaction()] * n_interactions
        else:
            # independent weights for each interaction block
            blocks = [build_interaction() for _ in range(n_interactions)]
        self.interactions = nn.ModuleList(blocks)

        if charged_systems:
            # trainable per-feature charge vector, initialised N(0, 1/sqrt(n))
            self.charge = nn.Parameter(torch.Tensor(1, n_atom_basis))
            self.charge.data.normal_(0, 1.0 / n_atom_basis**0.5)