Пример #1
0
 def get_energies_cg(self, atom_grid, energy_ndx):
     """Return the mean CG bond, angle, dihedral and LJ energies of a grid.

     Coordinates are extracted from ``atom_grid`` with ``avg_blob`` using the
     ``[grid] sigma_cg`` width and scored with the terms of ``self.energy_cg``.

     Args:
         atom_grid: voxel densities handed to ``avg_blob``.
         energy_ndx: ``(bond_ndx, angle_ndx, dih_ndx, lj_intra_ndx, lj_ndx)``
             index tensors; ``size()[1]`` of each is its number of entries.

     Returns:
         ``(bond, angle, dih, lj)`` -- each the ``torch.mean`` of the
         corresponding energy term, or a 0-d zero tensor when that index
         tensor has no entries.
     """
     coords = avg_blob(
         atom_grid,
         res=self.cfg.getint('grid', 'resolution'),
         width=self.cfg.getfloat('grid', 'length'),
         sigma=self.cfg.getfloat('grid', 'sigma_cg'),
         device=self.device,
     )
     # lj_ndx belongs to the 5-tuple contract, but only the intramolecular
     # pairs (lj_intra_ndx) are scored here.
     bond_ndx, angle_ndx, dih_ndx, lj_intra_ndx, lj_ndx = energy_ndx

     def mean_term(energy_fn, ndx):
         # Skip empty index sets instead of calling the energy term on them.
         if ndx.size()[1]:
             return torch.mean(energy_fn(coords, ndx))
         return torch.zeros([], dtype=torch.float32, device=self.device)

     return (
         mean_term(self.energy_cg.bond, bond_ndx),
         mean_term(self.energy_cg.angle, angle_ndx),
         mean_term(self.energy_cg.dih, dih_ndx),
         # Guard on the index that is actually passed to lj(). Previously the
         # guard looked at lj_ndx, so an empty lj_intra_ndx could still reach
         # lj() and a non-empty one was dropped whenever lj_ndx was empty.
         mean_term(self.energy_cg.lj, lj_intra_ndx),
     )
Пример #2
0
    def get_energies_out(self, atom_grid, coords_inter, energy_ndx):
        """Return mean bond, angle, dihedral and LJ energies of an output grid.

        Coordinates are recovered from ``atom_grid`` with ``avg_blob`` at the
        ``[grid] sigma_out`` width and scored with ``self.energy_out``. When
        ``self.out_env`` and ``self.n_env_mols`` are both truthy, the LJ term
        is evaluated on the generated coordinates concatenated (along dim 1)
        with ``coords_inter``, regardless of how many entries ``lj_ndx`` has.

        Args:
            atom_grid: voxel densities handed to ``avg_blob``.
            coords_inter: extra coordinates appended for the LJ term when the
                environment is enabled.
            energy_ndx: ``(bond_ndx, angle_ndx, dih_ndx, lj_ndx)`` index
                tensors.

        Returns:
            Tuple of four tensors, each the ``torch.mean`` of its term; a
            term with an empty index tensor contributes zero.
        """
        cfg = self.cfg
        coords = avg_blob(
            atom_grid,
            res=cfg.getint('grid', 'resolution'),
            width=cfg.getfloat('grid', 'length'),
            sigma=cfg.getfloat('grid', 'sigma_out'),
            device=self.device,
        )
        bond_ndx, angle_ndx, dih_ndx, lj_ndx = energy_ndx
        ff = self.energy_out

        def guarded(term, ndx):
            # An empty interaction list contributes a 0-d zero.
            if not ndx.size()[1]:
                return torch.zeros([], dtype=torch.float32, device=self.device)
            return term(coords, ndx)

        bonded = [
            guarded(ff.bond, bond_ndx),
            guarded(ff.angle, angle_ndx),
            guarded(ff.dih, dih_ndx),
        ]

        if self.out_env and self.n_env_mols:
            # lj_ndx presumably indexes into the combined generated +
            # environment coordinates -- TODO confirm against the caller.
            lj_energy = ff.lj(torch.cat((coords, coords_inter), 1), lj_ndx)
        else:
            lj_energy = guarded(ff.lj, lj_ndx)

        return tuple(torch.mean(energy) for energy in (*bonded, lj_energy))
Пример #3
0
    def dstr(self, atom_grid, energy_ndx):
        """Return smoothed histograms of bond, angle, dihedral and LJ terms.

        Coordinates are recovered from ``atom_grid`` with ``avg_blob`` at the
        ``[grid] sigma_out`` width, then each index set is histogrammed with
        the matching ``gauss_hist_*`` method.

        Args:
            atom_grid: voxel densities handed to ``avg_blob``.
            energy_ndx: ``(bond_ndx, angle_ndx, dih_ndx, lj_ndx)`` index
                tensors; ``size()[1]`` of each is its number of entries.

        Returns:
            ``(b_dstr, a_dstr, d_dstr, nb_dstr)``. A term whose index tensor
            is empty yields a zero tensor of shape ``(self.n_bins, 1)``.
        """
        coords = avg_blob(
            atom_grid,
            res=self.cfg.getint('grid', 'resolution'),
            width=self.cfg.getfloat('grid', 'length'),
            sigma=self.cfg.getfloat('grid', 'sigma_out'),
            device=self.device,
        )
        bond_ndx, angle_ndx, dih_ndx, lj_ndx = energy_ndx

        def histogram(hist_fn, ndx):
            # The original tested `ndx.size()[1] and ndx.size()[1]`; a single
            # emptiness check is equivalent.
            if ndx.size()[1]:
                return hist_fn(coords, ndx)
            return torch.zeros([self.n_bins, 1], dtype=torch.float32, device=self.device)

        b_dstr = histogram(self.gauss_hist_bond, bond_ndx)
        a_dstr = histogram(self.gauss_hist_angle, angle_ndx)
        # NOTE(review): dihedrals are histogrammed with gauss_hist_bond (kept
        # from the original); confirm this is intended and that no
        # dihedral-specific histogram exists.
        d_dstr = histogram(self.gauss_hist_bond, dih_ndx)
        nb_dstr = histogram(self.gauss_hist_nb, lj_ndx)
        return b_dstr, a_dstr, d_dstr, nb_dstr
Пример #4
0
 def get_energies_from_grid(self, atom_grid, energy_ndx):
     """Evaluate the unreduced bond, angle, dihedral and LJ energies of a grid.

     Coordinates come from ``avg_blob`` at the ``[grid] sigma`` width; each
     term of ``self.energy`` is called directly (no emptiness guard and no
     averaging).

     Args:
         atom_grid: voxel densities handed to ``avg_blob``.
         energy_ndx: ``(bond_ndx, angle_ndx, dih_ndx, lj_ndx)`` index tensors.

     Returns:
         ``(bond, angle, dih, lj)`` exactly as returned by ``self.energy``.
     """
     res = self.cfg.getint('grid', 'resolution')
     width = self.cfg.getfloat('grid', 'length')
     sigma = self.cfg.getfloat('grid', 'sigma')
     coords = avg_blob(atom_grid, res=res, width=width, sigma=sigma,
                       device=self.device)

     bond_ndx, angle_ndx, dih_ndx, lj_ndx = energy_ndx
     ff = self.energy
     return (
         ff.bond(coords, bond_ndx),
         ff.angle(coords, angle_ndx),
         ff.dih(coords, dih_ndx),
         ff.lj(coords, lj_ndx),
     )
Пример #5
0
    def featurize(self, grid, bond_ndx, angle_ndx1, angle_ndx2, dih_ndx,
                  lj_ndx):
        """Build a 3-channel (bond, angle, LJ) feature grid for the current atoms.

        Atom coordinates are recovered from ``grid`` with ``avg_blob`` at the
        ``[grid] sigma`` width and projected onto the voxel grid ``self.grid``
        by the ``*_grid`` methods of ``self.energy``.

        Args:
            grid: voxel densities of the current atoms.
            bond_ndx: index tensor for the bond channel.
            angle_ndx1, angle_ndx2: two index sets; both go through
                ``angle_grid1``, their results are summed and then passed to
                ``energy_to_prop``.
            dih_ndx: accepted for interface compatibility; currently unused
                (the dihedral channel is disabled).
            lj_ndx: index tensor for the LJ channel.

        Returns:
            ``torch.stack([bond, angle, lj], 1)``.
        """
        cfg = self.cfg
        coords = avg_blob(
            grid,
            res=cfg.getint('grid', 'resolution'),
            width=cfg.getfloat('grid', 'length'),
            sigma=cfg.getfloat('grid', 'sigma'),
            device=self.device,
        )

        ff = self.energy
        voxels = self.grid

        bond_channel = ff.bond_grid(voxels, coords, bond_ndx)

        first_angles = ff.angle_grid1(voxels, coords, angle_ndx1)
        second_angles = ff.angle_grid1(voxels, coords, angle_ndx2)
        angle_channel = ff.energy_to_prop(first_angles + second_angles)

        lj_channel = ff.lj_grid(voxels, coords, lj_ndx)

        channels = [bond_channel, angle_channel, lj_channel]
        return torch.stack(channels, 1)
Пример #6
0
    def predict(self, elems, initial, energy_ndx):
        """Generate atoms one after another and score the completed grid.

        At each step the current all-atom grid is featurized, summed with the
        CG features and fed, with Gaussian noise, to the generator. The
        generated density replaces ``aa_grid`` wherever ``repl`` is False.

        Args:
            elems: sequences zipped per step into
                ``(target_type, aa_featvec, repl)``.
            initial: ``(aa_grid, cg_features)``.
            energy_ndx: index tensors forwarded to ``get_energies_from_grid``.

        Returns:
            ``(coords, energy)``. ``coords`` is ``avg_blob`` applied to all
            generated atoms, concatenated along dim 1. ``energy`` is the sum of
            the bond, angle, dihedral and LJ terms of the final grid.
        """
        aa_grid, cg_features = initial

        produced = []
        for target_type, aa_featvec, repl in zip(*elems):
            condition = self.featurize(aa_grid, aa_featvec) + cg_features
            target_type = target_type.repeat(self.bs, 1)
            noise = torch.empty(
                [target_type.shape[0], self.z_dim],
                dtype=torch.float32,
                device=self.device,
            ).normal_()

            new_atom = self.generator(noise, target_type, condition)
            produced.append(new_atom)

            # Keep existing voxels where repl is True, take the new atom elsewhere.
            aa_grid = torch.where(repl[:, :, None, None, None], aa_grid, new_atom)

        all_atoms = torch.cat(produced, dim=1)

        cfg = self.cfg
        coords = avg_blob(
            all_atoms,
            res=cfg.getint('grid', 'resolution'),
            width=cfg.getfloat('grid', 'length'),
            sigma=cfg.getfloat('grid', 'sigma'),
            device=self.device,
        )

        e_bond, e_angle, e_dih, e_lj = self.get_energies_from_grid(aa_grid, energy_ndx)
        total_energy = e_bond + e_angle + e_dih + e_lj

        return coords, total_energy
Пример #7
0
    def val(self):
        """Regenerate all-atom coordinates for the validation set and write them.

        Each item yielded by ``Mol_N_Generator_AA`` (``train=False``,
        ``rand_rot=False``) is voxelized, passed through the generator in eval
        mode and turned back into coordinates with ``avg_blob``. The results
        are written per sample to ``<output_dir>/samples/<name>.gro`` through
        ``write_aa_gro_file``. The original atom positions are restored
        afterwards, also when an exception interrupts validation. The
        generator and critic are always put back in train mode, and the
        elapsed time is printed.
        """
        start = timer()

        resolution = self.cfg.getint('grid', 'resolution')
        grid_length = self.cfg.getfloat('grid', 'length')
        delta_s = grid_length / resolution
        sigma_inp = self.cfg.getfloat('grid', 'sigma_inp')
        sigma_out = self.cfg.getfloat('grid', 'sigma_out')
        grid = torch.from_numpy(make_grid_np(delta_s,
                                             resolution)).to(self.device)

        n_mols = self.cfg.getint('universe', 'n_mols')
        val_bs = self.cfg.getint('validate', 'batchsize')

        # Snapshot original positions; generated ones overwrite atom.pos below.
        pos_dict = {
            atom: atom.pos
            for sample in self.data.samples_val_inp
            for atom in sample.atoms
        }

        all_elems = list(Mol_N_Generator_AA(self.data,
                                            train=False,
                                            rand_rot=False,
                                            n_mols=n_mols))

        def voxel_features(batch, pos_key, feat_key):
            # Blur positions into blobs, weight each blob by its feature
            # vector and sum over dim 1 (presumably the atom axis).
            positions = torch.from_numpy(
                np.array([d[pos_key] for d in batch])).to(self.device).float()
            featvec = np.array([d[feat_key] for d in batch])
            blobs = self.to_voxel(positions, grid, sigma_inp)
            weighted = torch.from_numpy(
                featvec[:, :, :, None, None, None]).to(
                    self.device) * blobs[:, :, None, :, :, :]
            return torch.sum(weighted, 1)

        try:
            self.generator.eval()
            self.critic.eval()

            for _ in range(0, self.cfg.getint('validate', 'n_gibbs')):
                for ndx in range(0, len(all_elems), val_bs):
                    with torch.no_grad():
                        batch = all_elems[ndx:ndx + val_bs]

                        features = voxel_features(batch, 'inp_positions_intra',
                                                  'inp_intra_featvec')
                        features_inp_inter = voxel_features(
                            batch, 'inp_positions_inter', 'inp_inter_featvec')
                        gen_input = torch.cat((features, features_inp_inter),
                                              1)

                        batch_mols = np.array([d['inp_mols'] for d in batch])

                        if self.z_dim != 0:
                            z = torch.empty(
                                [features.shape[0], self.z_dim],
                                dtype=torch.float32,
                                device=self.device,
                            ).normal_()
                            fake_mol = self.generator(z, gen_input)
                        else:
                            fake_mol = self.generator(gen_input)

                        coords = avg_blob(
                            fake_mol,
                            res=resolution,
                            width=grid_length,
                            sigma=sigma_out,
                            device=self.device,
                        )
                        for positions, mols in zip(coords, batch_mols):
                            positions = positions.detach().cpu().numpy()
                            # All atoms of the group are expressed in the
                            # frame of the first molecule (mols[0].rot_mat).
                            positions = np.dot(positions, mols[0].rot_mat.T)
                            atoms = [atom for mol in mols for atom in mol.atoms]
                            # Bug fix: the original added `mol.com`, where
                            # `mol` was the leaked loop variable (the LAST
                            # molecule), while rotating with mols[0].
                            # NOTE(review): assumes the generator centres the
                            # group on mols[0].com -- confirm against
                            # Mol_N_Generator_AA.
                            origin = mols[0].com
                            for pos, atom in zip(positions, atoms):
                                atom.pos = pos + origin

            samples_dir = self.out.output_dir / "samples"
            samples_dir.mkdir(exist_ok=True)

            for sample in self.data.samples_val_inp:
                sample.write_aa_gro_file(samples_dir / (sample.name + ".gro"))
        finally:
            # Always put the original coordinates back, even if generation or
            # writing failed part-way.
            for atom, pos in pos_dict.items():
                atom.pos = pos
            self.generator.train()
            self.critic.train()
            print("validation took ", timer() - start, "secs")
Пример #8
0
    def predict(self, elems, initial, energy_ndx):
        """Generate atoms one at a time and plot every step (debug variant).

        For each step in ``elems`` the current all-atom grid is featurized,
        concatenated (dim 1) with ``cg_features`` and passed, together with
        Gaussian noise, to the generator. Each step prints ``target_type`` and
        ``bond_ndx`` and opens a blocking matplotlib figure (``plt.show()``).
        The figure shows the generated density, the first three feature
        channels of ``c_fake`` and the coordinates selected through
        ``bond_ndx``.

        Args:
            elems: sequences zipped per step into ``(target_type, repl,
                bond_ndx, angle_ndx1, angle_ndx2, dih_ndx, lj_ndx)``.
            initial: ``(aa_grid, cg_features)``.
            energy_ndx: index tensors forwarded to ``get_energies_from_grid``.

        Returns:
            ``(generated_atoms_coords, energy)``. The coordinates come from
            ``avg_blob`` over all generated atom grids (concatenated along
            dim 1). ``energy`` is the sum of the bond, angle, dihedral and LJ
            terms of the final ``aa_grid``.

        NOTE(review): this looks like leftover debugging code (prints and an
        interactive figure per generated atom). The ``.numpy()`` calls below
        presumably fail for CUDA tensors or tensors that require grad --
        confirm it is only run on CPU under ``torch.no_grad()``.
        """

        aa_grid, cg_features = initial

        generated_atoms = []
        for target_type, repl, bond_ndx, angle_ndx1, angle_ndx2, dih_ndx, lj_ndx in zip(
                *elems):
            # Condition = features of the current AA grid + CG features.
            fake_aa_features = self.featurize(aa_grid, bond_ndx, angle_ndx1,
                                              angle_ndx2, dih_ndx, lj_ndx)
            c_fake = torch.cat([fake_aa_features, cg_features], 1)
            #target_type = target_type.repeat(self.bs, 1)
            #print(target_type.size())
            z = torch.empty(
                [target_type.shape[0], self.z_dim],
                dtype=torch.float32,
                device=self.device,
            ).normal_()

            #generate fake atom
            fake_atom = self.generator(z, target_type, c_fake)
            generated_atoms.append(fake_atom)

            # Debug output.
            print(target_type)
            print(bond_ndx)
            # Coordinates of the atoms currently in the grid (before this
            # step's update).
            env_coords = avg_blob(
                aa_grid,
                res=self.cfg.getint('grid', 'resolution'),
                width=self.cfg.getfloat('grid', 'length'),
                sigma=self.cfg.getfloat('grid', 'sigma'),
                device=self.device,
            )

            # Column 2 of each bond_ndx entry is used to pick coordinates;
            # presumably the bonded partner's atom index -- TODO confirm.
            ndx2 = bond_ndx[:, :, 2]
            bond_atoms = [a[n] for n, a in zip(ndx2, env_coords)]
            #feature_grid = c_real.detach().cpu().numpy()

            fig = plt.figure(figsize=(10, 10))
            #_, _, nx, ny, nz = _feature_grid.shape
            res = self.cfg.getint('grid', 'resolution')
            # Cubic grid: all three extents equal the configured resolution.
            nx, ny, nz = self.cfg.getint('grid',
                                         'resolution'), self.cfg.getint(
                                             'grid',
                                             'resolution'), self.cfg.getint(
                                                 'grid', 'resolution')

            # Coordinate of the atom generated in this step.
            target_coord = avg_blob(
                fake_atom,
                res=self.cfg.getint('grid', 'resolution'),
                width=self.cfg.getfloat('grid', 'length'),
                sigma=self.cfg.getfloat('grid', 'sigma'),
                device=self.device,
            )

            # Voxel edge length; voxel (i, j, k) is drawn at
            # ((i + 0.5 - res / 2) * ds, ...).
            ds = self.cfg.getfloat('grid', 'length') / self.cfg.getint(
                'grid', 'resolution')
            # Panels 1-3: feature channel k of c_fake (voxels with values in
            # (0.5, 1.1)), plus the generated atom (red) and the coordinates
            # picked via bond_ndx (yellow).
            for k in range(0, 3):
                ax = fig.add_subplot(2, 3, k + 1, projection='3d')
                ax.scatter(target_coord[0, 0, 0],
                           target_coord[0, 0, 1],
                           target_coord[0, 0, 2],
                           alpha=0.5,
                           s=50,
                           marker='o',
                           c="red")

                for a in bond_atoms:
                    ax.scatter(a[0, 0],
                               a[0, 1],
                               a[0, 2],
                               alpha=0.5,
                               s=50,
                               marker='o',
                               c="yellow")

                for x in range(0, nx):
                    for y in range(0, ny):
                        for zz in range(0, nz):
                            if c_fake[0, k, x, y,
                                      zz] > 0.5 and c_fake[0, k, x, y,
                                                           zz] < 1.1:
                                ax.scatter((x + 0.5 - res / 2) * ds,
                                           (y + 0.5 - res / 2) * ds,
                                           (zz + 0.5 - res / 2) * ds,
                                           alpha=np.minimum(
                                               c_fake[0, k, x, y, zz].numpy(),
                                               1.0),
                                           s=25,
                                           marker='o',
                                           c="black")
                #for c in env_coords[0]:
                #ax.scatter(c[0], c[1], c[2], s=25, marker='o', c="blue")

                ax.set_xlim([-0.6, 0.6])
                ax.set_ylim([-0.6, 0.6])
                ax.set_zlim([-0.6, 0.6])

            # Panel 5: voxels of the generated density with values in
            # (0.01, 0.05) (red) and the extracted coordinate (blue).
            ax = fig.add_subplot(2, 3, 5, projection='3d')
            for x in range(0, nx):
                for y in range(0, ny):
                    for zz in range(0, nz):
                        if fake_atom[0, 0, x, y,
                                     zz] > 0.01 and fake_atom[0, 0, x, y,
                                                              zz] < 0.05:
                            ax.scatter(
                                (x + 0.5 - res / 2) * ds,
                                (y + 0.5 - res / 2) * ds,
                                (zz + 0.5 - res / 2) * ds,
                                alpha=np.minimum(fake_atom[0, 0, x, y, zz],
                                                 1.0).numpy() + 0.4,
                                s=25,
                                marker='o',
                                c="red")
            ax.scatter(target_coord[0, 0, 0],
                       target_coord[0, 0, 1],
                       target_coord[0, 0, 2],
                       s=80,
                       marker='o',
                       c="blue")
            ax.set_xlim([-0.6, 0.6])
            ax.set_ylim([-0.6, 0.6])
            ax.set_zlim([-0.6, 0.6])

            # plt.savefig("lj_grid.pdf")
            # Blocks until the figure window is closed.
            plt.show()

            #update aa grids
            # Keep existing voxels where repl is True, take fake_atom elsewhere.
            aa_grid = torch.where(repl[:, :, None, None, None], aa_grid,
                                  fake_atom)

        #generated_atoms = torch.stack(generated_atoms, dim=1)
        generated_atoms = torch.cat(generated_atoms, dim=1)

        generated_atoms_coords = avg_blob(
            generated_atoms,
            res=self.cfg.getint('grid', 'resolution'),
            width=self.cfg.getfloat('grid', 'length'),
            sigma=self.cfg.getfloat('grid', 'sigma'),
            device=self.device,
        )

        b_energy, a_energy, d_energy, l_energy = self.get_energies_from_grid(
            aa_grid, energy_ndx)
        energy = b_energy + a_energy + d_energy + l_energy
        #print(b_energy, a_energy, d_energy, l_energy )
        return generated_atoms_coords, energy
Пример #9
0
    def train_step_gen(self, elems, energy_ndx_inp, energy_ndx_out, backprop=True):
        """Run one generator step and return the logged loss components.

        The generator loss is built up from these terms:

        * the Wasserstein term from ``self.generator_loss``;
        * if ``self.use_ol``, ``self.ol_weight`` times the overlap loss;
        * if ``self.use_energy``, an energy prior weighted by
          ``self.energy_weight()``. In mode ``'match'`` it is the sum of mean
          absolute differences to the target's energies; in mode ``'min'``
          the generated energies are added directly;
        * if ``self.recon``, ``self.recon_weight`` times the reconstruction
          loss.

        Args:
            elems: ``(features, target, inp_coords_intra, inp_coords,
                out_coords_inter)``.
            energy_ndx_inp: index tensors for ``get_energies_inp``.
            energy_ndx_out: index tensors for ``get_energies_out``.
            backprop: if True, zero the generator gradients, backpropagate
                the loss and step ``self.opt_generator``.

        Returns:
            dict mapping ``"Generator/..."`` keys to detached NumPy values.
        """
        features, target, inp_coords_intra, inp_coords, out_coords_inter = elems

        g_loss = torch.zeros([], dtype=torch.float32, device=self.device)

        if self.z_dim == 0:
            fake_mol = self.generator(features)
        else:
            noise = torch.empty(
                [features.shape[0], self.z_dim],
                dtype=torch.float32,
                device=self.device,
            ).normal_()
            fake_mol = self.generator(noise, features)

        # A conditional critic sees the generated molecule stacked with its input.
        critic_input = torch.cat([fake_mol, features], dim=1) if self.cond else fake_mol
        critic_fake = self.critic(critic_input)

        g_wass = self.generator_loss(critic_fake)
        g_overlap = self.overlap_loss(features[:, :self.ff_inp.n_atom_chns],
                                      fake_mol[:, :self.ff_out.n_atoms])
        g_loss += (g_wass + self.ol_weight * g_overlap) if self.use_ol else g_wass

        e_bond_out, e_angle_out, e_dih_out, e_lj_out = self.get_energies_out(
            fake_mol, out_coords_inter, energy_ndx_out)
        e_bond_inp, e_angle_inp, e_dih_inp, e_lj_inp = self.get_energies_inp(
            inp_coords_intra, inp_coords, energy_ndx_inp)

        if self.use_energy:
            if self.prior_mode == 'match':
                t_bond, t_angle, t_dih, t_lj = self.get_energies_out(
                    target, out_coords_inter, energy_ndx_out)
                prior = (torch.mean(torch.abs(t_bond - e_bond_out))
                         + torch.mean(torch.abs(t_angle - e_angle_out))
                         + torch.mean(torch.abs(t_dih - e_dih_out))
                         + torch.mean(torch.abs(t_lj - e_lj_out)))
                g_loss += self.energy_weight() * prior
            elif self.prior_mode == 'min':
                g_loss += self.energy_weight() * (e_bond_out + e_angle_out + e_dih_out + e_lj_out)

        if not self.recon:
            rec_loss = torch.zeros([], dtype=torch.float32, device=self.device)
        else:
            out_coords = avg_blob(
                fake_mol,
                res=self.cfg.getint('grid', 'resolution'),
                width=self.cfg.getfloat('grid', 'length'),
                sigma=self.cfg.getfloat('grid', 'sigma_out'),
                device=self.device,
            )
            rec_loss = self.reconstruction_loss(out_coords, inp_coords_intra)
            g_loss += self.recon_weight * rec_loss

        if backprop:
            self.opt_generator.zero_grad()
            g_loss.backward()
            self.opt_generator.step()

        tracked = (
            ("Generator/wasserstein", g_wass),
            ("Generator/e_bond_out", e_bond_out),
            ("Generator/e_angle_out", e_angle_out),
            ("Generator/e_dih_out", e_dih_out),
            ("Generator/e_lj_out", e_lj_out),
            ("Generator/e_bond_inp", e_bond_inp),
            ("Generator/e_angle_inp", e_angle_inp),
            ("Generator/e_dih_inp", e_dih_inp),
            ("Generator/e_lj_inp", e_lj_inp),
            ("Generator/overlap", g_overlap),
            ("Generator/rec", rec_loss),
        )
        return {key: value.detach().cpu().numpy() for key, value in tracked}
Пример #10
0
    def val(self):
        """Reconstruct validation molecules atom by atom and write .gro files.

        Each validation molecule from ``Mol_Rec_Generator`` (``train=False``,
        ``rand_rot=False``) is voxelized in ``val_bs`` rotated copies. Atoms
        are then generated one at a time with the generator in eval mode. The
        rotations are undone and the copies averaged. Each of the
        ``[validate] n_gibbs`` rounds writes
        ``<output_dir>/samples/<name>_<n>.gro`` and then restores the original
        atom positions. The positions are also restored, and the networks put
        back in train mode, if validation fails part-way.
        """
        start = timer()

        resolution = self.cfg.getint('grid', 'resolution')
        grid_length = self.cfg.getfloat('grid', 'length')
        delta_s = grid_length / resolution
        sigma_inp = self.cfg.getfloat('grid', 'sigma_inp')
        sigma_out = self.cfg.getfloat('grid', 'sigma_out')
        grid = torch.from_numpy(make_grid_np(delta_s, resolution)).to(self.device)

        val_bs = self.cfg.getint('validate', 'batchsize')

        rot_mtxs = torch.from_numpy(rot_mtx_batch(val_bs)).to(self.device).float()
        # Derive the inverse rotations from the SAME matrices that rotate the
        # input. A second rot_mtx_batch call only matches if that helper is
        # deterministic.
        rot_mtxs_transposed = rot_mtxs.transpose(-2, -1)

        # Snapshot original positions; generated ones overwrite atom.pos below.
        pos_dict = {
            atom: atom.pos
            for sample in self.data.samples_val_inp
            for atom in sample.atoms
        }

        try:
            self.generator.eval()
            self.critic.eval()

            for n in range(0, self.cfg.getint('validate', 'n_gibbs')):
                for d in Mol_Rec_Generator(self.data, train=False, rand_rot=False):
                    with torch.no_grad():
                        inp_positions = torch.from_numpy(
                            np.array([d['positions']])).to(self.device).float()
                        # Broadcast the single molecule against every rotation
                        # matrix (presumably one rotated copy per batch entry).
                        inp_positions = torch.matmul(inp_positions, rot_mtxs)
                        aa_grid = self.to_voxel(inp_positions, grid, sigma_inp)

                        mol = d['mol']

                        elems = (d['featvec'], d['repl'])
                        elems = self.transpose(self.insert_dim(self.to_tensor(elems)))

                        generated_atoms = []
                        for featvec, repl in zip(*elems):
                            features = torch.sum(
                                aa_grid[:, :, None, :, :, :]
                                * featvec[:, :, :, None, None, None], 1)

                            if self.z_dim != 0:
                                z = torch.empty(
                                    [features.shape[0], self.z_dim],
                                    dtype=torch.float32,
                                    device=self.device,
                                ).normal_()
                                fake_atom = self.generator(z, features)
                            else:
                                fake_atom = self.generator(features)
                            generated_atoms.append(fake_atom)

                            # Keep existing voxels where repl is True, take the
                            # generated atom elsewhere.
                            aa_grid = torch.where(repl[:, :, None, None, None], aa_grid, fake_atom)

                        coords = avg_blob(
                            torch.cat(generated_atoms, dim=1),
                            res=resolution,
                            width=grid_length,
                            sigma=sigma_out,
                            device=self.device,
                        )

                        # Undo the rotations and average over the rotated copies.
                        coords = torch.matmul(coords, rot_mtxs_transposed)
                        coords = torch.sum(coords, 0) / val_bs

                        positions = coords.detach().cpu().numpy()
                        positions = np.dot(positions, mol.rot_mat.T)
                        for pos, atom in zip(positions, mol.atoms):
                            atom.pos = pos + mol.com

                samples_dir = self.out.output_dir / "samples"
                samples_dir.mkdir(exist_ok=True)

                for sample in self.data.samples_val_inp:
                    sample.write_aa_gro_file(samples_dir / (sample.name + "_" + str(n) + ".gro"))
                    # Reset so the next round starts from the original coordinates.
                    for a in sample.atoms:
                        a.pos = pos_dict[a]
        finally:
            # Guarantee the dataset keeps its original coordinates even if a
            # round failed before its reset above.
            for atom, pos in pos_dict.items():
                atom.pos = pos
            self.generator.train()
            self.critic.train()
            print("validation took ", timer()-start, "secs")
Пример #11
0
    def dstr_loss(self, real_atom_grid, fake_atom_grid, energy_ndx):
        real_coords = avg_blob(
            real_atom_grid,
            res=self.cfg.getint('grid', 'resolution'),
            width=self.cfg.getfloat('grid', 'length'),
            sigma=self.cfg.getfloat('grid', 'sigma_out'),
            device=self.device,
        )
        fake_coords = avg_blob(
            fake_atom_grid,
            res=self.cfg.getint('grid', 'resolution'),
            width=self.cfg.getfloat('grid', 'length'),
            sigma=self.cfg.getfloat('grid', 'sigma_out'),
            device=self.device,
        )
        bond_ndx, angle_ndx, dih_ndx, lj_ndx = energy_ndx

        if bond_ndx.size()[1] and bond_ndx.size()[1]:
            #b_dstr_real = self.gauss_hist_bond(self.energy.bond_dstr(real_coords, bond_ndx))
            #b_dstr_fake = self.gauss_hist_bond(self.energy.bond_dstr(fake_coords, bond_ndx))
            b_dstr_inp = self.gauss_hist_bond(real_coords, bond_ndx)
            b_dstr_out = self.gauss_hist_bond(fake_coords, bond_ndx)
            b_dstr_avg = 0.5 * (b_dstr_inp + b_dstr_out)

            b_dstr_loss = 0.5 * ((b_dstr_inp * (b_dstr_inp / b_dstr_avg).log()).sum(0) + (b_dstr_out * (b_dstr_out / b_dstr_avg).log()).sum(0))

            if self.step % 50 == 0:
                fig = plt.figure()
                ax = plt.gca()
                x = [h * 0.4/64 for h in range(0,64)]
                ax.plot(x, b_dstr_inp.detach().cpu().numpy()[:,0], label='inp')
                ax.plot(x, b_dstr_out.detach().cpu().numpy()[:,0], label='out')
                #ax.plot(x, a_dstr_avg.detach().cpu().numpy()[:,0], label='avg')
                #ax.text(0.1, 0.1, "JSD: "+str(a_dstr_loss.detach().cpu().numpy()))
                self.out.add_fig("bond", fig, global_step=self.step)
                plt.close(fig)


        else:
            b_dstr_loss = torch.zeros([], dtype=torch.float32, device=self.device)
        if angle_ndx.size()[1] and angle_ndx.size()[1]:
            a_dstr_inp = self.gauss_hist_angle(real_coords, angle_ndx)
            a_dstr_out = self.gauss_hist_angle(fake_coords, angle_ndx)
            a_dstr_avg = 0.5 * (a_dstr_inp + a_dstr_out)

            a_dstr_loss = 0.5 * ((a_dstr_inp * (a_dstr_inp / a_dstr_avg).log()).sum(0) + (a_dstr_out * (a_dstr_out / a_dstr_avg).log()).sum(0))

            if self.step % 50 == 0:
                fig = plt.figure()
                ax = plt.gca()
                x = [h * 180.0/64 for h in range(0,64)]
                ax.plot(x, a_dstr_inp.detach().cpu().numpy()[:,0], label='inp')
                ax.plot(x, a_dstr_out.detach().cpu().numpy()[:,0], label='out')
                #ax.plot(x, a_dstr_avg.detach().cpu().numpy()[:,0], label='avg')
                #ax.text(0.1, 0.1, "JSD: "+str(a_dstr_loss.detach().cpu().numpy()))
                self.out.add_fig("angle", fig, global_step=self.step)
                plt.close(fig)


            #print(a_dstr_loss)
            #print(a_dstr_loss.size())

        else:
            a_dstr_loss = torch.zeros([], dtype=torch.float32, device=self.device)
        if dih_ndx.size()[1] and dih_ndx.size()[1]:
            d_dstr_inp = self.gauss_hist_bond(real_coords, dih_ndx)
            d_dstr_out = self.gauss_hist_bond(fake_coords, dih_ndx)
            d_dstr_avg = 0.5 * (d_dstr_inp + d_dstr_out)

            d_dstr_loss = 0.5 * ((d_dstr_inp * (d_dstr_inp / d_dstr_avg).log()).sum(0) + (d_dstr_out * (d_dstr_out / d_dstr_avg).log()).sum(0))

            if self.step % 50 == 0:
                fig = plt.figure()
                ax = plt.gca()
                x = [h * 180.0/64 for h in range(0,64)]
                ax.plot(x, d_dstr_inp.detach().cpu().numpy()[:,0], label='ref')
                ax.plot(x, d_dstr_out.detach().cpu().numpy()[:,0], label='fake')
                #ax.plot(x, a_dstr_avg.detach().cpu().numpy()[:,0], label='avg')
                #ax.text(0.1, 0.1, "JSD: "+str(a_dstr_loss.detach().cpu().numpy()))
                self.out.add_fig("dih", fig, global_step=self.step)
                plt.close(fig)

            #print(d_dstr_loss)
            #print(d_dstr_loss.size())
        else:
            d_dstr_loss = torch.zeros([], dtype=torch.float32, device=self.device)


        if lj_ndx.size()[1] and lj_ndx.size()[1]:
            nb_dstr_inp = self.gauss_hist_nb(real_coords, lj_ndx)
            nb_dstr_out = self.gauss_hist_nb(fake_coords, lj_ndx)
            nb_dstr_avg = 0.5 * (nb_dstr_inp + nb_dstr_out)

            nb_dstr_loss = 0.5 * ((nb_dstr_inp * (nb_dstr_inp / nb_dstr_avg).log()).sum(0) + (nb_dstr_out * (nb_dstr_out / nb_dstr_avg).log()).sum(0))

            if self.step % 50 == 0:
                fig = plt.figure()
                ax = plt.gca()
                x = [h * 2.0/64 for h in range(0,64)]
                ax.plot(x, nb_dstr_inp.detach().cpu().numpy()[:,0], label='inp')
                ax.plot(x, nb_dstr_out.detach().cpu().numpy()[:,0], label='out')
                #ax.plot(x, a_dstr_avg.detach().cpu().numpy()[:,0], label='avg')
                #ax.text(0.1, 0.1, "JSD: "+str(a_dstr_loss.detach().cpu().numpy()))
                self.out.add_fig("nonbonded", fig, global_step=self.step)
                plt.close(fig)

            #print(b_dstr_loss)
            #print(b_dstr_loss.size())
        else:
            nb_dstr_loss = torch.zeros([], dtype=torch.float32, device=self.device)
        #print(torch.sum(b_dstr_loss))
        #print(torch.sum(a_dstr_loss))
        #print(torch.sum(d_dstr_loss))
        #print(torch.sum(nb_dstr_loss))

        return torch.sum(b_dstr_loss), torch.sum(a_dstr_loss), torch.sum(d_dstr_loss), torch.sum(nb_dstr_loss)
Пример #12
0
    def val(self):
        """Generate coarse-grained samples for the validation set and write them out.

        The generator and critic are switched to eval mode, and the generator
        is run without autograd over every element produced by
        ``Mol_Generator_AA(self.data, train=False, rand_rot=False)``, in
        batches of ``self.bs``.

        For each batch, the generated output is reduced to coordinates with
        ``avg_blob``. Those coordinates are multiplied by ``mol.rot_mat.T``,
        offset by ``mol.com`` and stored on each bead's ``pos`` attribute.

        Afterwards every sample in ``self.data.samples_val_aa`` is written to
        ``<self.out.output_dir>/samples/<sample.name>.gro``. The generator and
        critic are put back into training mode even if an exception is raised.
        """
        resolution = self.cfg.getint('grid', 'resolution')
        grid_length = self.cfg.getfloat('grid', 'length')
        delta_s = grid_length / resolution
        sigma_aa = self.cfg.getfloat('grid', 'sigma_aa')
        sigma_cg = self.cfg.getfloat('grid', 'sigma_cg')
        grid = torch.from_numpy(make_grid_np(delta_s, resolution)).to(self.device)

        g = Mol_Generator_AA(self.data, train=False, rand_rot=False)
        all_elems = list(g)

        try:
            self.generator.eval()
            self.critic.eval()

            # Inference only: no autograd graph is needed for any batch.
            with torch.no_grad():
                for ndx in range(0, len(all_elems), self.bs):
                    # Slicing clamps at the end of the list, so the last batch may be short.
                    batch = all_elems[ndx:ndx + self.bs]

                    aa_positions_intra = torch.from_numpy(
                        np.array([d['aa_positions_intra'] for d in batch])
                    ).to(self.device).float()
                    aa_intra_featvec = np.array(
                        [d['aa_intra_featvec'] for d in batch])
                    # Molecule objects are only iterated below; keep them in a plain
                    # list rather than letting np.array() try to unpack them.
                    mols = [d['aa_mol'] for d in batch]

                    aa_blobbs_intra = self.to_voxel(aa_positions_intra, grid,
                                                    sigma_aa)
                    # Broadcast each per-atom feature vector against that atom's
                    # voxel blob, then sum over axis 1 (presumably the atom axis).
                    # NOTE(review): unlike the positions, the feature vectors are not
                    # cast with .float(); confirm their dtype matches the generator.
                    features = torch.from_numpy(
                        aa_intra_featvec[:, :, :, None, None, None]).to(
                            self.device) * aa_blobbs_intra[:, :, None, :, :, :]
                    features = torch.sum(features, 1)

                    if self.n_env_mols:
                        # Same featurisation for the environment molecules, appended
                        # along the channel axis.
                        aa_positions_inter = torch.from_numpy(
                            np.array([d['aa_positions_inter'] for d in batch])
                        ).to(self.device).float()
                        aa_inter_featvec = np.array(
                            [d['aa_inter_featvec'] for d in batch])
                        aa_blobbs_inter = self.to_voxel(
                            aa_positions_inter, grid, sigma_aa)
                        features_inter = torch.from_numpy(
                            aa_inter_featvec[:, :, :, None, None, None]).to(
                                self.device) * aa_blobbs_inter[:, :,
                                                               None, :, :, :]
                        features_inter = torch.sum(features_inter, 1)
                        features = torch.cat((features, features_inter), 1)

                    if self.z_dim != 0:
                        z = torch.empty(
                            [features.shape[0], self.z_dim],
                            dtype=torch.float32,
                            device=self.device,
                        ).normal_()
                        fake_mol = self.generator(z, features)
                    else:
                        fake_mol = self.generator(features)

                    # Keep only the first ff_cg.n_atoms channels of the output.
                    fake_mol = fake_mol[:, :self.ff_cg.n_atoms]
                    coords = avg_blob(
                        fake_mol,
                        res=resolution,
                        width=grid_length,
                        sigma=sigma_cg,
                        device=self.device,
                    )
                    for positions, mol in zip(coords, mols):
                        positions = positions.detach().cpu().numpy()
                        # Map the generated coordinates with the molecule's stored
                        # rotation and offset before writing them onto its beads.
                        positions = np.dot(positions, mol.rot_mat.T)
                        for pos, bead in zip(positions, mol.beads):
                            bead.pos = pos + mol.com

            samples_dir = self.out.output_dir / "samples"
            # parents=True: do not fail if the output directory itself is missing.
            samples_dir.mkdir(parents=True, exist_ok=True)

            for sample in self.data.samples_val_aa:
                sample.write_gro_file(samples_dir / (sample.name + ".gro"))
        finally:
            self.generator.train()
            self.critic.train()