def schnet_batched_hessians(batch, model, device=0, energy_keys=("energy",)):
    """Compute Hessians of one or more model outputs for a padded batch.

    The batch is passed through ``pad``. The model's ``convolve`` and
    ``atomwisereadout`` are run on the padded batch, and the results are
    aggregated with ``batch_and_sum``. Each requested output is then
    differentiated twice by ``hess_from_pad`` w.r.t. the stacked coordinates.

    Before returning (and also if any of these steps raises), the keys
    "nbr_list", "nxyz" and "num_atoms" of the batch returned by ``pad`` are
    reset from "real_nbrs", "real_nxyz" and "real_num_atoms".

    Args:
        batch (dict): dictionary of props.
        model: model providing ``convolve(batch, xyz)`` and
            ``atomwisereadout(r)``.
        device: passed through unchanged to ``hess_from_pad``.
        energy_keys (iterable of str): result keys to compute Hessians for.
            Defaults to ``("energy",)``. A tuple is used instead of a list so
            the default cannot be mutated and shared between calls.

    Returns:
        dict: maps ``key + "_hess"`` to the value returned by
        ``hess_from_pad`` for each key in ``energy_keys``.
    """
    stack_xyz, xyz, batch = pad(batch)
    try:
        r, N, xyz = model.convolve(batch, xyz)
        r = model.atomwisereadout(r)
        results = batch_and_sum(r, N, list(batch.keys()), xyz)

        # NOTE(review): "real_num_atoms" presumably holds the un-padded
        # per-molecule atom counts saved by `pad`; confirm.
        num_atoms = batch["real_num_atoms"]
        hess_dic = {}
        for key in energy_keys:
            hess_dic[key + "_hess"] = hess_from_pad(
                stacked=stack_xyz,
                output=results[key],
                device=device,
                N=num_atoms,
            )
    finally:
        # Change these keys back to their original values. This runs in a
        # `finally` so a failure above does not leave the batch padded.
        # Pop-then-assign keeps the same final key order as before.
        for key, real_key in (
            ("nbr_list", "real_nbrs"),
            ("nxyz", "real_nxyz"),
            ("num_atoms", "real_num_atoms"),
        ):
            batch.pop(key, None)
            batch[key] = batch[real_key]
    return hess_dic
def forward(self, batch, xyz=None, **kwargs):
    """Convolve, read out atom-wise, and aggregate per molecule.

    Args:
        batch (dict): dictionary of props
        xyz (torch.tensor): (optional) coordinates, forwarded to
            ``self.convolve``; the coordinates it returns are the ones used
            for aggregation
        **kwargs: accepted for interface compatibility; unused

    Returns:
        dict: dictionary of results produced by ``batch_and_sum``
    """
    atom_feats, num_atoms, xyz = self.convolve(batch, xyz)
    return batch_and_sum(
        self.atomwisereadout(atom_feats),
        num_atoms,
        list(batch),
        xyz,
    )
def forward(self, batch, **kwargs):
    """Compute results with a temperature-conditioned readout.

    Node embeddings are convolved twice:
    - over the system neighbor list, using periodic offsets;
    - over the intramolecular neighbor list, without offsets.

    Each convolved feature set is combined with the learned temperature term
    ``self.temp_scale(batch['temp']).T`` using the operation selected by
    ``self.temp_type``. The two results are summed and read out atom-wise.
    The excluded-volume term ``self.V_ex(...)`` is added to
    ``results['energy']`` before the results are aggregated with
    ``batch_and_sum``.

    Args:
        batch (dict): dictionary of props; reads 'nxyz', 'num_atoms',
            'atoms_nbr_list', 'nbr_list', 'temp' and optionally 'offsets'.
        **kwargs: accepted for consistency with the other ``forward``
            signatures; unused.

    Returns:
        dict: dictionary of results

    Raises:
        ValueError: if ``self.temp_type`` is not one of
            'mult', 'div', 'sum', 'sub'.
    """
    # How the learned temperature term is applied to each feature set.
    combiners = {
        'mult': lambda feats, t: feats * t,
        'div': lambda feats, t: feats / t,
        'sum': lambda feats, t: feats + t,
        'sub': lambda feats, t: feats - t,
    }
    # Validate before doing any expensive work. Previously an unknown value
    # surfaced as an UnboundLocalError on `results`.
    try:
        combine = combiners[self.temp_type]
    except KeyError:
        raise ValueError(
            "Unknown temp_type {!r}; expected one of {}".format(
                self.temp_type, sorted(combiners)
            )
        ) from None

    r = batch['nxyz'][:, 0]
    xyz = batch['nxyz'][:, 1:4]
    N = batch['num_atoms'].reshape(-1).tolist()
    a_mol = batch['atoms_nbr_list']
    a_sys = batch['nbr_list']
    temp = batch['temp']

    # offsets take care of periodic boundary conditions; they are only
    # applied with the system neighbor list (nbr_list)
    offsets = batch.get('offsets', 0)

    xyz.requires_grad = True
    node_input = self.atom_embed(r.long()).squeeze()

    # system and intramolecular convolutions
    r_sys = self.SeqConv(node_input, xyz, a_sys, self.system_convolutions, offsets)
    r_mol = self.SeqConv(node_input, xyz, a_mol, self.molecule_convolutions)

    temp_learn = self.temp_scale(temp)
    # Excluded volume interactions
    r_ex = self.V_ex(xyz, a_sys, offsets)

    temp_t = temp_learn.T
    results = self.atomwisereadout(combine(r_sys, temp_t) + combine(r_mol, temp_t))

    # add excluded volume interactions
    results['energy'] += r_ex
    results = batch_and_sum(results, N, list(batch.keys()), xyz)
    return results
def forward(self, batch, **kwargs):
    """Run weighted-aggregation convolutions and aggregate per molecule.

    Args:
        batch (dict): dictionary of props; reads 'nxyz', 'num_atoms',
            'nbr_list', 'aggr_wgt' and optionally 'offsets'.
        **kwargs: accepted for interface compatibility; unused.

    Returns:
        dict: dictionary of results
    """
    r = batch['nxyz'][:, 0]
    xyz = batch['nxyz'][:, 1:4]
    N = batch['num_atoms'].reshape(-1).tolist()
    a = batch['nbr_list']
    aggr_wgt = batch['aggr_wgt']

    # offsets take care of periodic boundary conditions
    offsets = batch.get('offsets', 0)

    xyz.requires_grad = True

    # distance for every neighbor pair in `a`, shaped (num_pairs, 1);
    # offsets are added to each displacement vector before taking the norm
    e = (xyz[a[:, 0]] - xyz[a[:, 1]] + offsets).pow(2).sum(1).sqrt()[:, None]

    # embed atomic numbers (column 0 of nxyz); equal numbers map to equal
    # initial vectors
    r = self.atom_embed(r.long()).squeeze()

    # residual updates: each convolution adds its output to the features
    for conv in self.convolutions:
        r = r + conv(r=r, e=e, a=a, aggr_wgt=aggr_wgt)

    r = self.atomwisereadout(r)
    results = batch_and_sum(r, N, list(batch.keys()), xyz)
    return results
def forward(self, batch, **kwargs):
    """Sum system- and molecule-level convolutions, read out, and aggregate.

    Args:
        batch (dict): dictionary of props; reads 'nxyz', 'num_atoms',
            'atoms_nbr_list', 'nbr_list' and optionally 'offsets'.
        **kwargs: accepted for interface compatibility; unused.

    Returns:
        dict: dictionary of results produced by ``batch_and_sum``
    """
    nxyz = batch['nxyz']
    atomic_nums = nxyz[:, 0]
    xyz = nxyz[:, 1:4]
    mol_sizes = batch['num_atoms'].reshape(-1).tolist()
    intra_nbrs = batch['atoms_nbr_list']
    system_nbrs = batch['nbr_list']
    # periodic-boundary offsets; only the system neighbor list uses them
    offsets = batch.get('offsets', 0)

    xyz.requires_grad = True
    embedded = self.atom_embed(atomic_nums.long()).squeeze()

    system_feats = self.SeqConv(
        embedded, xyz, system_nbrs, self.system_convolutions, offsets
    )
    molecule_feats = self.SeqConv(
        embedded, xyz, intra_nbrs, self.molecule_convolutions
    )

    # NOTE(review): an excluded-volume energy term
    # (self.V_ex(xyz, nbr_list, offsets) added to results['energy']) was
    # commented out here; confirm it is intentionally disabled for this model.
    per_atom = self.atomwisereadout(system_feats + molecule_feats)
    return batch_and_sum(per_atom, mol_sizes, list(batch), xyz)