def make_model(params, device, world_size, weight_path):
    """
    Create a model and (optionally) wrap it in DistributedDataParallel.

    If ``<weight_path>/start_model`` exists, the whole pickled model stored
    there is loaded (mapped onto the CPU first). Otherwise a fresh model is
    built with ``get_model`` using ``params["model_type"]`` (default
    ``"SchNet"``).

    Args:
        params (dict): training/network parameters. ``"model_type"`` selects
            the architecture for a fresh model and ``"torch_par"`` (default
            True) controls whether the model is wrapped in
            DistributedDataParallel.
        device (int or str): local rank of the current gpu, or ``"cpu"``.
        world_size (int): total number of gpus altogether (not used in this
            function).
        weight_path (str): directory that may contain a ``start_model``
            checkpoint to resume from.

    Returns:
        model (nn.Module): the model moved to ``device``; wrapped in
            ``nn.parallel.DistributedDataParallel`` when ``torch_par`` is
            set. NOTE(review): wrapping presumably assumes the default
            process group has already been initialised by the caller.

    """

    start_model_path = os.path.join(weight_path, "start_model")

    if os.path.isfile(start_model_path):
        try:
            # The checkpoint is a whole pickled nn.Module, not a state dict.
            # torch >= 2.6 defaults to weights_only=True, which refuses to
            # unpickle it. Only point this at checkpoints you trust.
            model = torch.load(start_model_path, map_location="cpu",
                               weights_only=False)
        except TypeError:
            # Older torch releases have no ``weights_only`` keyword.
            model = torch.load(start_model_path, map_location="cpu")
        print(f"Loading model from {start_model_path}", flush=True)
    else:
        model = get_model(params=params,
                          model_type=params.get("model_type", "SchNet"))

    if device != "cpu":
        torch.cuda.set_device(device)
    model.to(device)

    if params.get("torch_par", True):
        # DDP requires device_ids=None for CPU modules; passing ["cpu"]
        # raises, so only pin a device id when running on a GPU.
        device_ids = None if device == "cpu" else [device]
        model = nn.parallel.DistributedDataParallel(model,
                                                    device_ids=device_ids)
    return model
Beispiel #2
0
    def testDynamics(self):
        """Run an NVE simulation of the first example structure.

        Builds an untrained model from a small hyperparameter set, attaches
        it to the atoms through ``NeuralFF`` and runs ``Dynamics`` with
        ``DEFAULTNVEPARAMS``. The test passes if no exception is raised.
        """
        import torch

        dataset = Dataset.from_file('../../examples/dataset.pth.tar')
        props = dataset[0]
        # Column 0 of nxyz is passed as atomic numbers and columns 1: as
        # positions (assumed (Z, x, y, z) rows — confirm against Dataset).
        atoms = ase.AtomsBatch(positions=props['nxyz'][:, 1:],
                               numbers=props['nxyz'][:, 0],
                               props=props)

        # initialize models
        params = {
            'n_atom_basis': 64,
            'n_filters': 64,
            'n_gaussians': 32,
            'n_convolutions': 2,
            'cutoff': 5.0,
            'trainable_gauss': True
        }

        model = get_model(params)

        # Choose the device at runtime instead of hard-coding a second GPU
        # ('cuda:1'), so the test also runs on single-GPU and CPU-only hosts.
        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        nff_ase = NeuralFF(model=model, device=device)
        atoms.set_calculator(nff_ase)

        nve = Dynamics(atoms, DEFAULTNVEPARAMS)
        nve.run()
Beispiel #3
0
def get_gnn_potential(assignments, sys_params):
    """Build a GNN network and its excluded-volume prior from hyperparameters.

    Args:
        assignments (dict): hyperparameter values. ``epsilon`` and ``sigma``
            configure the prior; ``n_atom_basis`` and ``n_filters`` are keys
            into ``width_dict``; ``cutoff``, ``gaussian_width`` and
            ``n_convolutions`` configure the network.
        sys_params: accepted for interface compatibility; not used here.

    Returns:
        tuple: ``(net, prior)`` — the network returned by ``get_model`` and
        an ``ExcludedVolume`` prior with ``power=12``.
    """
    prior_kwargs = dict(
        epsilon=assignments['epsilon'],
        sigma=assignments['sigma'],
        power=12,
    )

    # Number of Gaussians = how many widths fit inside the cutoff.
    network_kwargs = dict(
        n_atom_basis=width_dict[assignments['n_atom_basis']],
        n_filters=width_dict[assignments['n_filters']],
        n_gaussians=int(assignments['cutoff'] // assignments['gaussian_width']),
        n_convolutions=assignments['n_convolutions'],
        cutoff=assignments['cutoff'],
        trainable_gauss=False,
    )

    return get_model(network_kwargs), ExcludedVolume(**prior_kwargs)
Beispiel #4
0
                 'power': 12}

    size = 4
    L = 19.73 / size

    device = 'cpu'
    atoms = FaceCenteredCubic(directions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]],
                              symbol='H',
                              size=(size, size, size),
                              latticeconstant= L,
                              pbc=True)

    system = System(atoms, device=device)

    pair = PairPot(ExcludedVolume, lj_params,
                    cell=torch.Tensor(system.get_cell_len()), 
                    device=device,
                    cutoff=L/2,
                    ).to(device)

    model = get_model(params)
    PES = GNNPotentials(model, system.get_batch(), system.get_cell_len(), cutoff=5.0, device=system.device)

    system.set_temperature(298.0)


    # Todo test Pair pot with fixed atom index  



Beispiel #5
0
_DIVERGED_LOSS = 55.0  # objective returned when the trajectory or loss goes NaN


def _chain_index(n_atoms, offsets):
    """Return index tuples ``[i + o for o in offsets]`` along a linear chain.

    Only tuples whose largest index stays inside the chain (``<= n_atoms - 1``)
    are kept, so ``i`` runs over ``range(n_atoms - max(offsets))``. The
    offsets are expected to be non-negative.

    Returns:
        torch.LongTensor built from the nested list, i.e. shape
        ``(n_tuples, len(offsets))`` (an empty tensor if the chain is too
        short).
    """
    n_tuples = n_atoms - max(offsets)
    return torch.LongTensor(
        [[start + off for off in offsets] for start in range(n_tuples)])


def train(params, suggestion_id, project_name, device, n_epochs):
    """Train a GNN force field so MD of a 50-atom chain reproduces a helix.

    A reference helix is generated with ``gen_helix``. Its bond lengths
    (1-2 through 1-6 neighbours), consecutive angles and dihedrals, and the
    first-to-last-atom distance are used as targets. A straight chain is then
    simulated with ``NoseHooverChain`` under a ``Stack`` of a GNN potential
    (network from ``get_model``), a ``BondPotentials`` term and an
    ``ExcludedVolume`` pair term. The integrator's parameters are updated with
    Adam on the weighted mean squared deviation between the simulated and
    target geometry.

    Args:
        params (dict): hyperparameters. Keys read: ``T``, ``k0``,
            ``n_atom_basis``, ``n_filters``, ``n_gaussians``,
            ``n_convolutions``, ``cutoff``, ``epsilon``, ``sigma``, ``tau``,
            ``method``, ``lr``, ``dt`` and the loss weights ``l_bond``,
            ``l_dihe1``, ``l_angle1``, ``l_bond13``, ``l_bond14``,
            ``l_bond15``, ``l_bond16``, ``l_end2end``.
        suggestion_id: name of this run's output sub-directory.
        project_name: parent output directory. ``project_name/suggestion_id``
            is created with ``os.makedirs`` and so must not already exist.
        device: torch device the simulation runs on.
        n_epochs (int): number of simulate/update cycles. The first cycle
            only equilibrates; parameters are updated from the second on.

    Returns:
        The mean of the last (up to) 10 unweighted losses, ``55.0`` if the
        trajectory or the loss becomes NaN, or NaN if no epoch was scored
        (``n_epochs <= 1``).

    Side effects:
        Prints ``params`` and every epoch's loss, and writes ``train.xyz``
        (the logged trajectory) and ``loss.csv`` into
        ``project_name/suggestion_id``.
    """

    print(params)

    model_path = '{}/{}'.format(project_name, suggestion_id)
    os.makedirs(model_path)

    n_atoms = 50

    # Reference geometry with a leading batch dimension.
    xyz = torch.Tensor(gen_helix(15, n_atoms))[None, ...]

    # Index tuples along the chain for each internal coordinate.
    dihe1_top = _chain_index(n_atoms, (0, 1, 2, 3))
    angle1_top = _chain_index(n_atoms, (0, 1, 2))
    bond_top = _chain_index(n_atoms, (0, 1))
    bond13_top = _chain_index(n_atoms, (0, 2))
    bond14_top = _chain_index(n_atoms, (0, 3))
    bond15_top = _chain_index(n_atoms, (0, 4))
    bond16_top = _chain_index(n_atoms, (0, 5))
    # First-to-last atom, derived from n_atoms rather than hard-coded.
    end2end = torch.LongTensor([[0, n_atoms - 1]])

    # Loss terms in summation order: (name, geometry function, indices).
    # Term ``x`` is weighted by ``params['l_x']`` in the training loss.
    loss_terms = [
        ('bond', compute_bond, bond_top),
        ('dihe1', compute_dihe, dihe1_top),
        ('angle1', compute_angle, angle1_top),
        ('bond13', compute_bond, bond13_top),
        ('bond14', compute_bond, bond14_top),
        ('bond15', compute_bond, bond15_top),
        ('bond16', compute_bond, bond16_top),
        ('end2end', compute_bond, end2end),
    ]

    # Targets are measured on the reference helix. Targets and index tensors
    # are moved to ``device`` once here instead of on every epoch.
    targets = {name: fn(xyz, top).to(device).squeeze()
               for name, fn, top in loss_terms}
    device_tops = {name: top.to(device) for name, _, top in loss_terms}

    bond_len = compute_bond(xyz, bond_top)[0, 0].item()

    # define system objects: a straight chain along x, spaced by bond_len,
    # starting at (50, 50, 50) in a 100 x 100 x 100 cell.
    chain = Atoms(numbers=[1.] * n_atoms,
                  positions=[
                      np.array([50., 50., 50.]) +
                      np.array([bond_len, 0., 0.]) * i for i in range(n_atoms)
                  ],
                  cell=[100.0, 100.0, 100.0])

    from torchmd.system import System
    from torchmd.potentials import ExcludedVolume

    system = System(chain, device=device)
    system.set_temperature(params['T'])

    from torchmd.interface import BondPotentials, GNNPotentials, Stack, PairPotentials
    bondenergy = BondPotentials(system, bond_top, params['k0'], bond_len)

    from nff.train import get_model

    gnnparams = {
        'n_atom_basis': params['n_atom_basis'],
        'n_filters': params['n_filters'],
        'n_gaussians': params['n_gaussians'],
        'n_convolutions': params['n_convolutions'],
        'cutoff': params['cutoff']
    }

    schnet = get_model(gnnparams)

    GNN = GNNPotentials(
        system,
        schnet,
        cutoff=gnnparams['cutoff'],
    )

    # Bonded pairs are excluded from the pair term.
    pair = PairPotentials(
        system,
        ExcludedVolume(**{
            'epsilon': params['epsilon'],
            'sigma': params['sigma'],
            'power': 10
        }),
        cutoff=2.5,
        ex_pairs=bond_top).to(system.device)

    FF = Stack({'gnn': GNN, 'prior': bondenergy, 'pair': pair})

    from torchmd.md import NoseHooverChain, Simulations

    diffeq = NoseHooverChain(FF,
                             system,
                             Q=50.0,
                             T=params['T'],
                             num_chains=5,
                             adjoint=True).to(device)

    tau = params['tau']
    sim = Simulations(system, diffeq, wrap=False, method=params['method'])

    optimizer = torch.optim.Adam(list(diffeq.parameters()), lr=params['lr'])
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           'min',
                                                           min_lr=5e-5,
                                                           verbose=True,
                                                           factor=0.5,
                                                           patience=20,
                                                           threshold=5e-5)

    loss_log = []

    for epoch in range(n_epochs):
        _, q_t, _ = sim.simulate(steps=tau, frequency=int(tau), dt=params['dt'])

        # Abort with the sentinel as soon as the trajectory blows up.
        if torch.isnan(q_t.reshape(-1)).sum().item() > 0:
            return _DIVERGED_LOSS

        # The first epoch only equilibrates the chain; no parameter update.
        if epoch == 0:
            continue

        losses = {
            name: (fn(q_t, device_tops[name]) - targets[name]).pow(2).mean()
            for name, fn, _ in loss_terms
        }
        loss = sum(params['l_' + name] * losses[name]
                   for name, _, _ in loss_terms)
        # Unweighted total: logged, and used for the returned objective.
        loss_record = sum(losses.values())

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        scheduler.step(loss)

        print(loss.item())
        record = loss_record.item()
        if math.isnan(record):
            return _DIVERGED_LOSS

        loss_log.append(record)

    from utils import to_mdtraj
    traj = to_mdtraj(system, sim.log)
    traj.center_coordinates()
    traj.save_xyz("{}/train.xyz".format(model_path))

    np.savetxt("{}/loss.csv".format(model_path), np.array(loss_log))

    if not loss_log:
        # n_epochs <= 1: only the equilibration epoch ran, nothing was scored.
        return float('nan')
    return np.array(loss_log[-10:]).mean()