def batch_and_sum(dict_input, N, predict_keys, xyz):
    """Pool per-atom outputs into graph-level properties.

    A key of ``dict_input`` is pooled when either the key itself or
    ``key + "_grad"`` appears in ``predict_keys``. Pooling is delegated to
    ``split_and_sum(val, N)``. When ``key + "_grad"`` is requested, the
    gradient of the pooled value with respect to ``xyz`` is also stored under
    that name.

    Args:
        dict_input (dict): raw model outputs keyed by property name.
        N (list): split sizes forwarded to ``split_and_sum`` (presumably the
            number of atoms in each graph -- confirm against caller).
        predict_keys (list): property names (and ``*_grad`` names) to return.
        xyz (tensor): coordinates the gradients are taken with respect to.

    Returns:
        dict: pooled properties and any requested gradients.
    """
    results = dict()
    for key, val in dict_input.items():
        wants_value = key in predict_keys
        wants_grad = key + "_grad" in predict_keys
        if not (wants_value or wants_grad):
            # Neither the property nor its gradient was asked for.
            continue
        # The pooled value is needed in every remaining case, either as the
        # prediction itself or as the quantity being differentiated.
        results[key] = split_and_sum(val, N)
        if wants_grad:
            results[key + "_grad"] = compute_grad(inputs=xyz, output=results[key])
    return results
def forward(self, batch, xyz=None):
    """Run the atomwise network and sum its outputs over each molecule.

    Args:
        batch (dict): batch dictionary; ``batch["num_atoms"]`` gives the
            split sizes used to separate per-atom outputs.
        xyz: optional coordinates forwarded to ``self.atomwise``; the
            (possibly new) coordinates it returns are used for gradients.

    Returns:
        results (dict): for every atomwise output key, a 1-D tensor of
        per-molecule sums; plus, for each key in ``self.grad_keys``, the
        gradient of the matching summed output with respect to ``xyz``.
    """
    atom_out, xyz = self.atomwise(batch, xyz)
    counts = batch["num_atoms"].detach().cpu().tolist()

    # One scalar per molecule: split the per-atom rows, then sum each chunk.
    results = {
        key: torch.stack([chunk.sum() for chunk in torch.split(val, counts)])
        for key, val in atom_out.items()
    }

    for grad_key in self.grad_keys:
        base_value = results[grad_key.replace("_grad", "")]
        results[grad_key] = compute_grad(output=base_value, inputs=xyz)
    return results
def forward(self, r, batch, xyz, take_grad=True):
    """Predict each output energy from topology nets plus a learned offset.

    For every output key (e.g. ``energy_0``, ``energy_1``), the energies of
    the associated topology networks are summed. A per-molecule offset is
    then added, obtained by summing ``self.offset[output_key](r)`` over the
    atoms of each molecule.

    Args:
        r: features passed to each topology net and to the offset module.
        batch (dict): batch dictionary; ``batch["num_atoms"]`` gives the
            per-molecule split sizes for the offset.
        xyz: coordinates the gradients are taken with respect to.
        take_grad (bool): if True, also return ``<output_key>_grad``.

    Returns:
        dict: ``output[output_key]`` is the total energy, with shape matching
        ``E["total"] + offset`` (offset is reshaped to ``(-1, 1)``). When
        ``take_grad`` is True, ``output[output_key + "_grad"]`` is the gradient
        of that same returned energy with respect to ``xyz``.
    """
    # Split sizes are the same for every output key, so compute them once.
    num_atoms = batch["num_atoms"].detach().cpu().tolist()

    output = dict()
    # Loop through output keys (e.g. energy_0 and energy_1).
    for output_key, top_set in self.auto_modules.items():
        E = {key: 0.0 for key in list(self.terms.keys()) + ['total']}
        # Loop through associated topology nets (e.g. BondNet0 and AngleNet0).
        for top, top_net in top_set.items():
            E[top] = top_net(r, batch, xyz)
            E['total'] += E[top]

        per_mol_offset = torch.split(self.offset[output_key](r), num_atoms)
        offset = torch.stack([torch.sum(item) for item in per_mol_offset]).reshape(-1, 1)

        energy = E["total"] + offset
        output[output_key] = energy
        if take_grad:
            # Differentiate the energy that is actually returned (including
            # the offset), so forces stay consistent with the energy whenever
            # ``r`` depends on ``xyz``.
            output[output_key + "_grad"] = compute_grad(inputs=xyz, output=energy)
    return output
def forward(self, batch):
    """Harmonic bond energy, k * (r - r_0)**2, and its coordinate gradient.

    Bonds are grouped into consecutive chunks of ``batch["num_bonds"][0]``
    bonds; only the first count is used as the chunk size, so every group is
    assumed to contain the same number of bonds.

    Args:
        batch (dict): must provide "num_bonds", "nxyz" (coordinates in
            columns 1:4), "bonds" (index pairs) and "bond_len" (r_0).

    Returns:
        dict: ``'energy'`` is the sum of all chunk energies, reshaped to
        (1, 1). ``'energy_grad'`` is ``compute_grad`` of the per-chunk
        energies with respect to the coordinates.
    """
    bond_counts = batch["num_bonds"].tolist()
    coords = batch['nxyz'][:, 1:4]
    coords.requires_grad = True
    pairs = batch["bonds"]
    eq_len = batch['bond_len'].squeeze()

    bond_vec = coords[pairs[:, 0]] - coords[pairs[:, 1]]
    lengths = bond_vec.pow(2).sum(-1).sqrt()
    if self.dif_bond_len:
        # One row per chunk of bond_counts[0] bonds, so r_0 can broadcast per row.
        lengths = torch.stack(list(torch.split(lengths, bond_counts[0])))

    e = self.k * (lengths - eq_len).pow(2)

    if self.dif_bond_len:
        chunk_energy = e.sum(1)
    else:
        chunk_energy = torch.stack(
            [piece.sum(0) for piece in torch.split(e, bond_counts[0])])

    result = {}
    result['energy'] = chunk_energy.sum().reshape(1, 1)
    result['energy_grad'] = compute_grad(inputs=coords, output=chunk_energy)
    return result
def forward(self, q, v):
    """Estimate pressure as an ideal-gas term minus a cell-virial term.

    Args:
        q: positions; passed to ``self.model._reset_topology`` and indexed by
            the neighbor list it returns.
        v: velocities, scaled by ``self.mass[:, None]`` to form momenta.

    Returns:
        ``Pideal - Pvirial``, where ``Pideal`` is
        ``n_atoms * Temperature / volume`` and ``Pvirial`` is the gradient of
        the pair energy with respect to ``self.model.cell_diag``, scaled by
        ``1 / (cell[0] * cell[1])``.
    """
    nbr, _, offsets = self.model._reset_topology(q)
    cell = self.model.cell_diag
    # Track gradients w.r.t. the box so the virial can be taken as dU/dcell.
    # NOTE(review): this sets requires_grad on the model's own cell tensor.
    cell.requires_grad = True

    # Rebuild pair displacements from the positions so that they depend on
    # ``cell`` through the offsets (presumably periodic image shifts).
    dis = q[nbr[:, 0]] - q[nbr[:, 1]] - offsets[nbr[:, 0], nbr[:, 1]] * cell
    dis_norm = dis.pow(2).sum(-1).sqrt()
    # The pair function belongs to ``self.model``, the same object that
    # supplied the topology and the cell above.
    u = self.model.model(dis_norm).sum()

    # Temperature from equipartition: KE = N_dof * T / 2
    # (assumes T is expressed in energy units, i.e. k_B = 1 -- confirm).
    N_dof = self.mass.shape[0] * self.system.dim
    p = v * self.mass[:, None]
    ke = 0.5 * (p.pow(2) / self.mass[:, None]).sum()
    Temperature = ke / (N_dof * 0.5)

    Pideal = self.system.get_number_of_atoms() * Temperature / self.system.get_volume()
    # NOTE(review): dividing by cell[0] * cell[1] corresponds to a 2D area;
    # confirm this normalization for 3D systems.
    Pvirial = compute_grad(cell, u) * (1 / (cell[0] * cell[1]))
    Pressure = Pideal - Pvirial
    return Pressure
def forward(self, t, state):
    """Time derivative of the Nose-Hoover chain equations of motion.

    Args:
        t: integration time; not used by this function (presumably required
            by the ODE solver's call signature).
        state: sequence ``(v, q, p_v)`` of velocities, positions and
            thermostat-chain momenta. The chain must have at least two
            entries, because ``p_v[1]`` is read.

    Returns:
        tuple ``(dv/dt, dq/dt, dp_v/dt)``:
        - ``dv/dt`` is ``(f - (p_v[0] / Q[0]) * p) / m``, where ``f = -dU/dq``.
        - ``dq/dt`` is ``v``.
        - ``dp_v/dt`` is the chain thermostat update.
    """
    with torch.set_grad_enabled(True):
        v = state[0]
        q = state[1]
        p_v = state[2]
        if self.adjoint:
            # Positions must require grad so forces can be taken by autograd.
            q.requires_grad = True

        p = v * self.mass[:, None]
        sys_ke = 0.5 * (p.pow(2) / self.mass[:, None]).sum()

        u = self.model(q)
        f = -compute_grad(inputs=q, output=u.sum(-1))

        # Friction from the first thermostat acting on the momenta.
        coupled_forces = (p_v[0] * p.reshape(-1) / self.Q[0]).reshape(-1, 3)
        # ``f - coupled_forces`` is dp/dt; the state holds velocities
        # (dq/dt = v), so divide by mass to obtain dv/dt.
        dvdt = (f - coupled_forces) / self.mass[:, None]

        # Thermostat chain. self.T enters as an energy (k_B = 1 presumably).
        dpvdt_0 = 2 * (sys_ke - self.T * self.N_dof * 0.5) - p_v[0] * p_v[1] / self.Q[1]
        dpvdt_mid = (p_v[:-2].pow(2) / self.Q[:-2] - self.T) - p_v[2:] * p_v[1:-1] / self.Q[2:]
        dpvdt_last = p_v[-2].pow(2) / self.Q[-2] - self.T

    return (dvdt, v, torch.cat((dpvdt_0[None], dpvdt_mid, dpvdt_last[None])))
def forward(self, t, state):
    """Time derivative of Newton's equations for constant-energy dynamics.

    Args:
        t: integration time; not used by this function (presumably required
            by the ODE solver's call signature).
        state: sequence ``(v, q)`` of velocities and positions.

    Returns:
        tuple ``(dv/dt, dq/dt)`` equal to ``(f / m, v)``, where ``f = -dU/dq``
        and ``m`` is ``self.mass[:, None]``.
    """
    with torch.set_grad_enabled(True):
        v = state[0]
        q = state[1]
        if self.adjoint:
            # Positions must require grad so forces can be taken by autograd.
            q.requires_grad = True
        u = self.model(q)
        f = -compute_grad(inputs=q, output=u.sum(-1))
        # The state holds velocities (dq/dt = v), so dv/dt is acceleration.
        dvdt = f / self.mass[:, None]
    return (dvdt, v)
def forward(self, batch):
    """Per-molecule harmonic bond energy, k * (r - r_0)**2, and its gradient.

    Args:
        batch (dict): must provide "num_bonds" (bonds per molecule), "nxyz"
            (coordinates in columns 1:4), "bonds" (index pairs) and
            "bond_len" (r_0).

    Returns:
        dict: ``"energy"`` is a tensor with one length-1 row per molecule.
        ``"energy_grad"`` is its gradient with respect to the coordinates.
    """
    per_mol_bonds = batch["num_bonds"].tolist()
    coords = batch["nxyz"][:, 1:4]
    coords.requires_grad = True
    first_atom = batch["bonds"][:, 0]
    second_atom = batch["bonds"][:, 1]
    ref_len = batch["bond_len"].squeeze()

    dist = (coords[first_atom] - coords[second_atom]).pow(2).sum(-1).sqrt()
    bond_energy = self.k * (dist - ref_len).pow(2)

    # Split bond energies by molecule and sum each group.
    mol_energy = torch.stack([
        chunk.sum(0).reshape(1)
        for chunk in torch.split(bond_energy, per_mol_bonds)
    ])
    return {
        "energy": mol_energy,
        "energy_grad": compute_grad(inputs=coords, output=mol_energy),
    }
def add_grad(self, batch, results, xyz):
    """Add gradients of predictions for which ground truth is present.

    For every predicted key ``k`` in ``results``, if ``batch`` contains
    ``k + "_grad"``, then ``results[k + "_grad"]`` is set to the gradient of
    ``results[k]`` with respect to ``xyz``.

    Args:
        batch (dict): dictionary of props (ground truth).
        results (dict): dictionary of predicted values; updated in place.
        xyz (torch.tensor): coordinates the gradients are taken with respect to.

    Returns:
        results (dict): the same dict, updated with any requested gradients.
    """
    suffix = "_grad"
    # Map each candidate gradient key to the prediction it differentiates.
    # Strip only the trailing suffix, so a base key that itself contains
    # "_grad" (e.g. "x_grad_y") still resolves to the correct prediction.
    grad_to_base = {key + suffix: key for key in results.keys()}

    for key in batch.keys():
        base_key = grad_to_base.get(key)
        if base_key is not None:
            results[key] = compute_grad(inputs=xyz, output=results[base_key])
    return results