def schnet_batched_hessians(batch, model, device=0, energy_keys=None):
    """Compute Hessians of one or more model outputs for a padded batch.

    The batch is padded with ``pad``. The model's ``convolve`` and
    ``atomwisereadout`` are then run on the padded coordinates, the atomwise
    outputs are pooled with ``batch_and_sum``, and ``hess_from_pad`` builds
    one Hessian per requested output key.

    Args:
        batch (dict): dictionary of props. It is passed through ``pad``. The
            padded batch returned by ``pad`` must provide ``real_num_atoms``,
            ``real_nbrs`` and ``real_nxyz``.
        model: model exposing ``convolve`` and ``atomwisereadout``.
        device: device forwarded to ``hess_from_pad``.
        energy_keys (iterable of str, optional): output keys to
            differentiate. Defaults to ``["energy"]``.

    Returns:
        dict: maps ``key + "_hess"`` to the result of ``hess_from_pad`` for
        each key in ``energy_keys``.
    """
    # None sentinel instead of a shared mutable list default.
    if energy_keys is None:
        energy_keys = ["energy"]

    stack_xyz, xyz, batch = pad(batch)
    try:
        r, N, xyz = model.convolve(batch, xyz)
        r = model.atomwisereadout(r)
        results = batch_and_sum(r, N, list(batch.keys()), xyz)

        # Hessians use the pre-padding atom counts, not the N from convolve.
        N = batch["real_num_atoms"]
        hess_dic = {}
        for key in energy_keys:
            hess_dic[key + "_hess"] = hess_from_pad(stacked=stack_xyz,
                                                    output=results[key],
                                                    device=device,
                                                    N=N)
    finally:
        # Change the padded keys back to their original values. This is done
        # in `finally` so the batch is restored even if a model call raises.
        # NOTE(review): restoration is only visible to the caller if `pad`
        # returns (or mutates) the caller's dict -- confirm against `pad`.
        for key, real_key in (("nbr_list", "real_nbrs"),
                              ("nxyz", "real_nxyz"),
                              ("num_atoms", "real_num_atoms")):
            batch.pop(key, None)
            batch[key] = batch[real_key]

    return hess_dic
Example #2
0
    def forward(self, batch, xyz=None, **kwargs):
        """Run the convolution stack and atomwise readout, then pool per molecule.

        Args:
            batch (dict): dictionary of props.
            xyz (torch.tensor, optional): coordinates forwarded to
                ``self.convolve``. Defaults to ``None``.
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            dict: dictionary of results produced by ``batch_and_sum``.
        """
        features, num_atoms, coords = self.convolve(batch, xyz)
        atomwise = self.atomwisereadout(features)
        prop_keys = list(batch.keys())
        return batch_and_sum(atomwise, num_atoms, prop_keys, coords)
Example #3
0
    def forward(self, batch):
        """Predict properties from temperature-modulated convolution features.

        System-level and molecule-level convolution features are each combined
        with a learned temperature scaling (per ``self.temp_type``), summed,
        and read out atomwise. An excluded-volume term from ``self.V_ex`` is
        then added to ``results['energy']`` before pooling with
        ``batch_and_sum``.

        Args:
            batch (dict): dictionary of props. Reads ``nxyz``, ``num_atoms``,
                ``atoms_nbr_list``, ``nbr_list``, ``temp`` and optionally
                ``offsets``.

        Returns:
            dict: dictionary of results from ``batch_and_sum``.

        Raises:
            ValueError: if ``self.temp_type`` is not one of ``'mult'``,
                ``'div'``, ``'sum'`` or ``'sub'``.
        """
        # How the learned temperature term is applied to each feature set.
        # The op is resolved before any convolution work, so an unsupported
        # temp_type fails fast with a clear message. Without this check,
        # `results` would be left unbound and raise UnboundLocalError.
        combine_ops = {
            'mult': lambda feats, t: feats * t,
            'div': lambda feats, t: feats / t,
            'sum': lambda feats, t: feats + t,
            'sub': lambda feats, t: feats - t,
        }
        try:
            combine = combine_ops[self.temp_type]
        except KeyError:
            raise ValueError(
                f"Unsupported temp_type {self.temp_type!r}; expected one of "
                f"{sorted(combine_ops)}") from None

        r = batch['nxyz'][:, 0]
        xyz = batch['nxyz'][:, 1:4]
        N = batch['num_atoms'].reshape(-1).tolist()
        a_mol = batch['atoms_nbr_list']
        a_sys = batch['nbr_list']
        temp = batch['temp']
        # offsets take care of periodic boundary conditions and are only
        # applied to the system neighbor list (nbr_list)
        offsets = batch.get('offsets', 0)
        xyz.requires_grad = True

        node_input = self.atom_embed(r.long()).squeeze()

        # system and molecule convolutions share the same node embeddings
        r_sys = self.SeqConv(node_input, xyz, a_sys, self.system_convolutions,
                             offsets)
        r_mol = self.SeqConv(node_input, xyz, a_mol,
                             self.molecule_convolutions)

        # transpose once and reuse it for both feature sets
        temp_learn_t = self.temp_scale(temp).T
        # excluded volume interactions
        r_ex = self.V_ex(xyz, a_sys, offsets)

        results = self.atomwisereadout(combine(r_sys, temp_learn_t) +
                                       combine(r_mol, temp_learn_t))

        # add excluded volume interactions
        results['energy'] += r_ex
        results = batch_and_sum(results, N, list(batch.keys()), xyz)
        return results
Example #4
0
    def forward(self, batch, **kwargs):
        """Embed atoms, apply residual convolutions, and pool per molecule.

        Args:
            batch (dict): dictionary of props. Reads ``nxyz``, ``num_atoms``,
                ``nbr_list``, ``aggr_wgt`` and optionally ``offsets``.
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            dict: dictionary of results from ``batch_and_sum``.
        """
        nxyz = batch['nxyz']
        atomic_numbers = nxyz[:, 0]
        xyz = nxyz[:, 1:4]
        num_atoms = batch['num_atoms'].reshape(-1).tolist()
        nbrs = batch['nbr_list']
        aggr_wgt = batch['aggr_wgt']
        # offsets shift neighbor displacements for periodic boundary conditions
        offsets = batch.get('offsets', 0)

        xyz.requires_grad = True

        # pairwise distances along each neighbor-list edge, as a column
        displacement = xyz[nbrs[:, 0]] - xyz[nbrs[:, 1]] + offsets
        e = displacement.pow(2).sum(1).sqrt()[:, None]

        # embeddings depend only on atomic number, so periodic images share
        # the vector of their in-cell atom
        r = self.atom_embed(atomic_numbers.long()).squeeze()

        # residual updates; periodicity enters through the offset-aware `e`
        for conv in self.convolutions:
            r = r + conv(r=r, e=e, a=nbrs, aggr_wgt=aggr_wgt)

        atomwise = self.atomwisereadout(r)
        return batch_and_sum(atomwise, num_atoms, list(batch.keys()), xyz)
    def forward(self, batch, **kwargs):
        """Sum system- and molecule-level convolution features and read out.

        Args:
            batch (dict): dictionary of props. Reads ``nxyz``, ``num_atoms``,
                ``atoms_nbr_list``, ``nbr_list`` and optionally ``offsets``.
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            dict: dictionary of results from ``batch_and_sum``.

        Note:
            No excluded-volume term is added here; ``self.V_ex`` is not called.
        """
        nxyz = batch['nxyz']
        atomic_numbers = nxyz[:, 0]
        xyz = nxyz[:, 1:4]
        num_atoms = batch['num_atoms'].reshape(-1).tolist()
        mol_nbrs = batch['atoms_nbr_list']
        sys_nbrs = batch['nbr_list']
        # periodic offsets are only applied to the system neighbor list
        offsets = batch.get('offsets', 0)

        xyz.requires_grad = True
        embedded = self.atom_embed(atomic_numbers.long()).squeeze()

        sys_feats = self.SeqConv(embedded, xyz, sys_nbrs,
                                 self.system_convolutions, offsets)
        mol_feats = self.SeqConv(embedded, xyz, mol_nbrs,
                                 self.molecule_convolutions)

        readout = self.atomwisereadout(sys_feats + mol_feats)
        return batch_and_sum(readout, num_atoms, list(batch.keys()), xyz)