예제 #1
0
def batch_and_sum(dict_input, N, predict_keys, xyz):
    """Pool outputs into per-batch properties and add requested gradients.

    For every entry ``key`` of ``dict_input``:

    * if ``key + "_grad"`` is in ``predict_keys``, the value is pooled with
      ``split_and_sum`` and stored under ``key`` (even when ``key`` itself
      was not requested), and its gradient with respect to ``xyz`` is stored
      under ``key + "_grad"``;
    * otherwise, if ``key`` is in ``predict_keys``, only the pooled value is
      stored;
    * otherwise the entry is skipped.

    Args:
        dict_input (dict): raw outputs keyed by property name.
        N (list): split sizes passed to ``split_and_sum``; presumably the
            number of atoms in each graph of the batch (confirm with callers).
        predict_keys (list): property names (and ``"<name>_grad"`` gradient
            names) that should appear in the result.
        xyz (tensor): xyz of the molecule; gradients are taken with respect
            to it.

    Returns:
        dict: batched and pooled results, plus any requested gradients.
    """
    results = dict()

    for key, val in dict_input.items():
        grad_key = key + "_grad"
        want_grad = grad_key in predict_keys

        # Neither the property nor its gradient was requested.
        if not want_grad and key not in predict_keys:
            continue

        # split the outputs back into batches and pool them
        results[key] = split_and_sum(val, N)

        # A gradient needs the pooled value, which is why the value is kept
        # above even when only the gradient was requested.
        if want_grad:
            results[grad_key] = compute_grad(inputs=xyz, output=results[key])

    return results
예제 #2
0
    def forward(self, batch, xyz=None):
        """
        Run the atomwise model and pool its outputs per molecule.

        Args:
            batch (dict): batch dictionary; ``batch["num_atoms"]`` provides the
                split sizes used to separate the atomwise outputs.
            xyz (torch.Tensor, optional): forwarded to ``self.atomwise``; the
                tensor it hands back is the input for every gradient.
        Returns:
            results (dict): per-molecule sums of each atomwise output, plus one
                entry per key of ``self.grad_keys``.
        """
        atom_outputs, xyz = self.atomwise(batch, xyz)
        atoms_per_mol = batch["num_atoms"].detach().cpu().tolist()

        # One value per molecule: split each atomwise tensor into chunks of
        # atoms_per_mol and sum every chunk.
        results = {
            name: torch.stack(
                [chunk.sum() for chunk in torch.split(tensor, atoms_per_mol)]
            )
            for name, tensor in atom_outputs.items()
        }

        # Each gradient key differentiates the pooled value whose name is the
        # gradient key with "_grad" removed.
        for grad_name in self.grad_keys:
            pooled = results[grad_name.replace("_grad", "")]
            results[grad_name] = compute_grad(output=pooled, inputs=xyz)

        return results
예제 #3
0
파일: modules.py 프로젝트: torchmd/mdgrad
    def forward(self, r, batch, xyz, take_grad=True):
        """Sum topology-net energies plus a pooled offset for each output key.

        Args:
            r: input passed to every topology net and to ``self.offset[key]``
                (presumably per-atom features — TODO confirm).
            batch (dict): batch dictionary; ``batch["num_atoms"]`` gives the
                split sizes used to pool the offset.
            xyz (torch.Tensor): coordinates the gradient is taken with respect to.
            take_grad (bool): if True, also store ``<output_key>_grad``.

        Returns:
            dict: ``output[output_key]`` is the summed topology energy plus the
                pooled offset (the offset has shape (num_chunks, 1));
                ``output[output_key + "_grad"]`` is added when ``take_grad``.
        """
        output = dict()

        # The split sizes do not depend on the output key, so compute them once.
        N = batch["num_atoms"].detach().cpu().tolist()

        # loop through output keys (e.g. energy_0 and energy_1)
        for output_key, top_set in self.auto_modules.items():
            # Sum the associated topology nets (e.g. BondNet0 and AngleNet0, or
            # BondNet1 and AngleNet1), calling them in their stored order.
            total_energy = 0.0
            for top_net in top_set.values():
                total_energy = total_energy + top_net(r, batch, xyz)

            # Pool the offset into one value per chunk of N, as a column.
            offset = torch.split(self.offset[output_key](r), N)
            offset = torch.stack([torch.sum(item) for item in offset]).reshape(-1, 1)

            output[output_key] = total_energy + offset

            if take_grad:
                # NOTE(review): the gradient is of the topology energy only; the
                # offset term is excluded. If ``r`` depends on ``xyz`` this makes
                # the gradient inconsistent with ``output[output_key]`` — confirm
                # that this is intended.
                grad = compute_grad(inputs=xyz, output=total_energy)
                output[output_key + "_grad"] = grad

        return output
예제 #4
0
파일: modules.py 프로젝트: yingli2009/T-NFF
    def forward(self, batch):
        """Harmonic bond energy ``self.k * (r - r_0)**2`` and its gradient.

        Args:
            batch (dict): reads ``"nxyz"`` (columns 1:4 are used as the
                coordinates), ``"bonds"`` (pairs of atom indices into those
                coordinates), ``"bond_len"`` (reference lengths ``r_0``) and
                ``"num_bonds"`` (bond counts used to split the bonds).

        Returns:
            dict: ``"energy"`` — the sum of all of ``E``, reshaped to (1, 1);
                ``"energy_grad"`` — ``compute_grad`` of ``E`` with respect to
                the coordinates.
        """

        result = {}

        num_bonds = batch["num_bonds"].tolist()

        # Take the coordinate columns and mark them as requiring grad so the
        # energy can be differentiated with respect to them.
        xyz = batch['nxyz'][:, 1:4]
        xyz.requires_grad = True
        bond_list = batch["bonds"]

        r_0 = batch['bond_len'].squeeze()

        # Euclidean length of every bond.
        r = (xyz[bond_list[:, 0]] - xyz[bond_list[:, 1]]).pow(2).sum(-1).sqrt()

        if self.dif_bond_len:
            # Rearrange r into rows of num_bonds[0] bonds; torch.stack needs
            # every chunk to have that same length.
            # NOTE(review): r_0 must broadcast against this 2-D r — confirm its
            # shape in this mode.
            r = torch.stack([r for r in torch.split(r, num_bonds[0])])

        e = self.k * (r - r_0).pow(2)

        if self.dif_bond_len:
            E = e.sum(1)
        else:
            # NOTE(review): splits by the first entry of num_bonds only, which
            # assumes every split has the same number of bonds — confirm.
            E = torch.stack([e.sum(0) for e in torch.split(e, num_bonds[0])])

        # The reported energy is a single total for the whole batch.
        result['energy'] = E.sum().reshape(1, 1)
        result['energy_grad'] = compute_grad(inputs=xyz, output=E)

        return result
예제 #5
0
    def forward(self, q, v):
        """Instantaneous pressure as an ideal-gas term minus a virial term.

        The virial term is the derivative of the pair energy with respect to
        ``self.model.cell_diag``, obtained by automatic differentiation.

        Args:
            q (torch.Tensor): positions, indexed per atom by the neighbor list.
            v (torch.Tensor): velocities; momenta are ``v * self.mass[:, None]``.

        Returns:
            torch.Tensor: ``Pideal - Pvirial``.
        """
        # Only the neighbor list and image offsets are reused; the pair
        # displacements are rebuilt below as an explicit function of ``cell``.
        nbr, _, offsets = self.model._reset_topology(q)
        cell = self.model.cell_diag
        # NOTE: this sets requires_grad on the model's own cell tensor; the
        # flag stays set after this call returns.
        cell.requires_grad = True
        dis = (q[nbr[:, 0]] - q[nbr[:, 1]]
               - offsets[nbr[:, 0], nbr[:, 1]] * cell)

        dis_norm = dis.pow(2).sum(-1).sqrt()
        # Pair energy from the pair function held by the model wrapper
        # (assumes ``self.model.model`` is that function — TODO confirm).
        u = self.model.model(dis_norm).sum()

        # compute temperature from the kinetic energy: KE = N_dof * T / 2
        N_dof = self.mass.shape[0] * self.system.dim
        p = v * self.mass[:, None]
        ke = 0.5 * (p.pow(2) / self.mass[:, None]).sum()
        Temperature = ke / (N_dof * 0.5)

        n_atoms = self.system.get_number_of_atoms()
        Pideal = n_atoms * Temperature / self.system.get_volume()

        # NOTE(review): the virial is normalized by cell[0] * cell[1] (an area)
        # while the ideal term divides by get_volume(); this looks 2-D specific
        # — confirm for 3-D systems.
        Pvirial = compute_grad(inputs=cell, output=u) * (1 / (cell[0] * cell[1]))

        return Pideal - Pvirial
예제 #6
0
    def forward(self, t, state):
        """Time derivatives of the state under Nose-Hoover-chain dynamics.

        Args:
            t: integration time; not used by the dynamics (kept for the
                ``forward(t, state)`` solver interface).
            state (tuple): ``(v, q, p_v)`` — velocities, positions and the
                momenta of the thermostat chain.

        Returns:
            tuple: ``(dv/dt, dq/dt, dp_v/dt)``; ``dq/dt`` is ``v`` itself.
        """
        with torch.set_grad_enabled(True):
            v, q, p_v = state[0], state[1], state[2]

            if self.adjoint:
                q.requires_grad = True

            mass = self.mass[:, None]
            p = v * mass

            sys_ke = 0.5 * (p.pow(2) / mass).sum()

            u = self.model(q)
            f = -compute_grad(inputs=q, output=u.sum(-1))

            # Friction from the first thermostat, (p_v[0] / Q[0]) * p. It is
            # elementwise in p, so no spatial dimension is assumed.
            coupled_forces = p_v[0] * p / self.Q[0]

            # f - coupled_forces is dp/dt; since p = v * mass, dv/dt is that
            # quantity divided by the mass.
            dvdt = (f - coupled_forces) / mass

            # Chain momenta. The first thermostat is driven by 2 * KE against
            # the target T * N_dof. Each later one is driven by its
            # predecessor's p_v**2 / Q against T. Every link except the last
            # is damped by its successor.
            dpvdt_0 = 2 * (sys_ke - self.T * self.N_dof *
                           0.5) - p_v[0] * p_v[1] / self.Q[1]
            dpvdt_mid = (p_v[:-2].pow(2) / self.Q[:-2] -
                         self.T) - p_v[2:] * p_v[1:-1] / self.Q[2:]
            dpvdt_last = p_v[-2].pow(2) / self.Q[-2] - self.T

        return (dvdt, v, torch.cat(
            (dpvdt_0[None], dpvdt_mid, dpvdt_last[None])))
예제 #7
0
    def forward(self, t, state):
        """Time derivatives of the state under constant-energy dynamics.

        Args:
            t: integration time; not used by the dynamics (kept for the
                ``forward(t, state)`` solver interface).
            state (tuple): ``(v, q)`` — velocities and positions.

        Returns:
            tuple: ``(dv/dt, dq/dt)`` with ``dv/dt = f / mass`` and
                ``dq/dt = v``.
        """
        with torch.set_grad_enabled(True):
            v = state[0]
            q = state[1]

            if self.adjoint:
                q.requires_grad = True

            u = self.model(q)
            f = -compute_grad(inputs=q, output=u.sum(-1))
            # The state carries velocities, so the derivative is force / mass.
            dvdt = f / self.mass[:, None]

        return (dvdt, v)
예제 #8
0
    def forward(self, batch):
        """Harmonic bond energy per split of bonds and its coordinate gradient.

        Args:
            batch (dict): reads ``"nxyz"`` (columns 1:4 are the coordinates),
                ``"bonds"`` (pairs of atom indices), ``"bond_len"`` (reference
                lengths) and ``"num_bonds"`` (bond counts used to split the
                bonds, presumably one per molecule).

        Returns:
            dict: ``"energy"`` — shape (num_splits, 1), the sum of
                ``self.k * (r - r_0)**2`` over each split's bonds;
                ``"energy_grad"`` — its gradient with respect to the coordinates.
        """
        bond_counts = batch["num_bonds"].tolist()

        coords = batch["nxyz"][:, 1:4]
        coords.requires_grad = True
        pairs = batch["bonds"]

        ref_len = batch["bond_len"].squeeze()

        # Euclidean length of every bond.
        delta = coords[pairs[:, 0]] - coords[pairs[:, 1]]
        bond_len = delta.pow(2).sum(-1).sqrt()

        bond_energy = self.k * (bond_len - ref_len).pow(2)

        # One single-element row per split.
        per_split = [
            chunk.sum(0).reshape(1)
            for chunk in torch.split(bond_energy, bond_counts)
        ]
        energy = torch.stack(per_split)

        return {
            "energy": energy,
            "energy_grad": compute_grad(inputs=coords, output=energy),
        }
예제 #9
0
    def add_grad(self, batch, results, xyz):
        """
        Add any required gradients of the predictions.

        For every predicted property ``name`` already in ``results``, the
        gradient ``"<name>_grad"`` is computed when that key appears in
        ``batch`` (i.e. the ground truth contains it).

        Args:
            batch (dict): dictionary of props
            results (dict): dictionary of predicted values
            xyz (torch.tensor): (optional) coordinates
        Returns:
            results (dict): results updated with any gradients
                requested.
        """
        # Map each possible gradient key to the property it differentiates.
        # Stripping only the "_grad" suffix keeps base names that contain
        # "_grad" elsewhere intact. Only keys present on entry are mapped, so
        # gradients added below are not differentiated again.
        grad_to_base = {base + "_grad": base for base in results}

        for key in batch.keys():
            # if the batch with the ground truth contains one of
            # these keys, then compute its predicted value
            base = grad_to_base.get(key)
            if base is not None:
                results[key] = compute_grad(inputs=xyz,
                                            output=results[base])

        return results