def __init__(self, data, cfg, train=True):
    """Prepare the multi-molecule dataset and its voxel-grid settings.

    Args:
        data: Project data container, forwarded to ``Mol_N_Generator``.
        cfg: Config parser providing the ``universe``, ``grid``,
            ``training`` and ``model`` options read below.
        train: Forwarded to ``Mol_N_Generator`` to choose the data split.
    """
    self.data = data
    self.train = train

    # Molecule counts (``getint`` already returns an int).
    self.n_env_mols = cfg.getint('universe', 'n_env_mols')
    self.n_mols = cfg.getint('universe', 'n_mols')

    # Elements are collected once, with the generator's own rotation turned
    # off; ``self.rand_rot`` is only stored here for use elsewhere.
    self.elems = Mol_N_Generator(
        data, train=train, rand_rot=False, n_mols=self.n_mols).all_elems()

    # Grid geometry: voxel spacing = box length / number of voxels per side.
    self.resolution = cfg.getint('grid', 'resolution')
    self.delta_s = cfg.getfloat('grid', 'length') / self.resolution
    self.sigma_inp = cfg.getfloat('grid', 'sigma_inp')
    self.sigma_out = cfg.getfloat('grid', 'sigma_out')

    self.rand_rot = cfg.getboolean('training', 'rand_rot')
    if self.rand_rot:
        print("using random rotations during training...")

    self.align = int(cfg.getboolean('universe', 'align'))
    self.out_env = cfg.getboolean('model', 'out_env')
    self.grid = make_grid_np(self.delta_s, self.resolution)
def __init__(self, data, cfg, train=True):
    """Gather recurrent-generator elements and voxel-grid settings.

    Heavy-atom elements are always collected in a non-Gibbs and a Gibbs
    variant; the matching hydrogen variants are added only when the
    ``[training] hydrogens`` option is enabled.

    Args:
        data: Project data container, forwarded to every generator.
        cfg: Config parser providing the ``training``, ``grid`` and
            ``universe`` options read below.
        train: Forwarded to every ``Recurrent_Generator``.
    """
    self.data = data

    # (hydrogens, gibbs) combinations, in the order they are collected.
    variants = [(False, False), (False, True)]
    if cfg.getboolean('training', 'hydrogens'):
        variants += [(True, False), (True, True)]

    generators = [
        Recurrent_Generator(data, hydrogens=with_h, gibbs=use_gibbs,
                            train=train, rand_rot=False)
        for with_h, use_gibbs in variants
    ]
    self.elems = [elem for gen in generators for elem in gen.all_elems()]

    # Grid geometry: voxel spacing = box length / number of voxels per side.
    self.resolution = cfg.getint('grid', 'resolution')
    self.delta_s = cfg.getfloat('grid', 'length') / self.resolution
    self.sigma = cfg.getfloat('grid', 'sigma')

    self.rand_rot = cfg.getboolean('training', 'rand_rot')
    if self.rand_rot:
        print("using random rotations during training...")

    self.align = int(cfg.getboolean('universe', 'align'))
    self.grid = make_grid_np(self.delta_s, self.resolution)
def val(self):
    """Generate structures for the validation inputs and write them as .gro files.

    All multi-molecule validation elements are built once. For each batch:
    - the intra- and inter-molecular input positions are voxelised;
    - they are weighted by their feature vectors and concatenated along the
      channel axis;
    - the generator output is turned back into coordinates with
      ``avg_blob`` and assigned to the atoms of the batch's input molecules.

    Every sample in ``self.data.samples_val_inp`` is then written to
    ``<output_dir>/samples/<name>.gro``.

    Generator and critic are put in eval mode for the run and always set back
    to train mode. The input atoms' original positions are always restored,
    also when an exception interrupts generation or writing.
    """
    start = timer()
    resolution = self.cfg.getint('grid', 'resolution')
    grid_length = self.cfg.getfloat('grid', 'length')
    delta_s = grid_length / resolution
    sigma_inp = self.cfg.getfloat('grid', 'sigma_inp')
    sigma_out = self.cfg.getfloat('grid', 'sigma_out')
    grid = torch.from_numpy(make_grid_np(delta_s, resolution)).to(self.device)
    n_mols = self.cfg.getint('universe', 'n_mols')
    val_bs = self.cfg.getint('validate', 'batchsize')

    def blob_features(positions, featvec):
        # Voxelise the positions via self.to_voxel (width sigma_inp), weight
        # each atom's grid by its feature vector and sum over atoms (axis 1).
        blobs = self.to_voxel(
            torch.from_numpy(positions).to(self.device).float(), grid,
            sigma_inp)
        weighted = (torch.from_numpy(
            featvec[:, :, :, None, None, None]).to(self.device) *
                    blobs[:, :, None, :, :, :])
        return torch.sum(weighted, 1)

    # Snapshot the input coordinates so they can be restored afterwards.
    pos_dict = {
        atom: atom.pos
        for sample in self.data.samples_val_inp
        for atom in sample.atoms
    }

    all_elems = list(
        Mol_N_Generator_AA(self.data, train=False, rand_rot=False,
                           n_mols=n_mols))

    try:
        self.generator.eval()
        self.critic.eval()
        # NOTE(review): all_elems is materialised once above, so every pass
        # sees the same inputs; only the last pass's coordinates are written.
        for _ in range(self.cfg.getint('validate', 'n_gibbs')):
            for ndx in range(0, len(all_elems), val_bs):
                batch = all_elems[ndx:ndx + val_bs]
                with torch.no_grad():
                    features = blob_features(
                        np.array([d['inp_positions_intra'] for d in batch]),
                        np.array([d['inp_intra_featvec'] for d in batch]))
                    features_inter = blob_features(
                        np.array([d['inp_positions_inter'] for d in batch]),
                        np.array([d['inp_inter_featvec'] for d in batch]))
                    gen_input = torch.cat((features, features_inter), 1)

                    if self.z_dim != 0:
                        z = torch.empty(
                            [features.shape[0], self.z_dim],
                            dtype=torch.float32,
                            device=self.device,
                        ).normal_()
                        fake_mol = self.generator(z, gen_input)
                    else:
                        fake_mol = self.generator(gen_input)

                    coords = avg_blob(
                        fake_mol,
                        res=resolution,
                        width=grid_length,
                        sigma=sigma_out,
                        device=self.device,
                    )

                    # Plain list: np.array over lists of molecule objects
                    # builds an object array and raises on ragged lists.
                    batch_mols = [d['inp_mols'] for d in batch]
                    for positions, mols in zip(coords, batch_mols):
                        # Undo the rotation and the translation with the same
                        # reference molecule (mols[0]), instead of shifting by
                        # whichever molecule a loop variable last held.
                        # NOTE(review): assumes Mol_N_Generator_AA centres the
                        # local frame on mols[0].com — confirm.
                        ref = mols[0]
                        positions = np.dot(positions.detach().cpu().numpy(),
                                           ref.rot_mat.T)
                        atoms = [atom for mol in mols for atom in mol.atoms]
                        for pos, atom in zip(positions, atoms):
                            atom.pos = pos + ref.com

        samples_dir = self.out.output_dir / "samples"
        samples_dir.mkdir(exist_ok=True)
        for sample in self.data.samples_val_inp:
            sample.write_aa_gro_file(samples_dir / (sample.name + ".gro"))
    finally:
        # Restore the input coordinates even if generation or writing failed.
        for atom, pos in pos_dict.items():
            atom.pos = pos
        self.generator.train()
        self.critic.train()
    print("validation took ", timer() - start, "secs")
def validate(self, samples_dir=None):
    """Run multi-pass validation sampling and evaluate every pass.

    Passes, in order:
    - heavy atoms, then hydrogens, with ``ref_pos=True``;
    - ``self.n_gibbs`` rounds, each a heavy-atom pass followed by a hydrogen
      pass, with ``gibbs=True`` and ``ref_pos=False``.

    Within a pass, each element is voxelised under the batch of rotations
    from ``rot_mtx_batch(self.bs)`` and passed to ``self.predict``. The
    candidate whose entry in ``energies`` is smallest is rotated back and
    written to the element's atoms. After each pass ``Stats.evaluate``
    stores results under ``<epoch>_<pass index>``. Finally ``kick_atoms()``
    is called on every validation sample.

    Args:
        samples_dir: Optional sub-directory name below
            ``self.out.output_dir``. It is created if given; otherwise
            ``self.out.samples_dir`` is used.
    """
    if not samples_dir:
        samples_dir = self.out.samples_dir
    else:
        samples_dir = self.out.output_dir / samples_dir
        make_dir(samples_dir)

    stats = Stats(self.data, dir=samples_dir / "stats")
    print("Saving samples in {}".format(samples_dir), "...", end='')

    cfg = self.cfg
    resolution = cfg.getint('grid', 'resolution')
    delta_s = cfg.getfloat('grid', 'length') / cfg.getint('grid', 'resolution')
    sigma = cfg.getfloat('grid', 'sigma')
    grid = torch.from_numpy(make_grid_np(delta_s, resolution)).to(self.device)

    def to_device_float(array):
        # numpy array -> float tensor on the model's device
        return torch.from_numpy(array).to(self.device).float()

    # NOTE(review): the inverse rotations come from a second rot_mtx_batch
    # call; this only matches rot_mtxs if rot_mtx_batch is deterministic.
    rot_mtxs = to_device_float(rot_mtx_batch(self.bs))
    rot_mtxs_transposed = to_device_float(rot_mtx_batch(self.bs,
                                                        transpose=True))

    def make_pass(hydrogens, gibbs, ref_pos):
        return iter(
            Generator(self.data,
                      hydrogens=hydrogens,
                      gibbs=gibbs,
                      train=False,
                      rand_rot=False,
                      pad_seq=False,
                      ref_pos=ref_pos))

    data_generators = [make_pass(False, False, True),
                       make_pass(True, False, True)]
    data_generators += [
        make_pass(hydrogens, True, False)
        for _ in range(self.n_gibbs)
        for hydrogens in (False, True)
    ]

    try:
        self.generator.eval()
        self.critic.eval()
        for pass_ndx, data_gen in enumerate(data_generators):
            start = timer()
            for d in data_gen:
                with torch.no_grad():
                    aa_grid = self.to_voxel(
                        torch.matmul(to_device_float(d['aa_pos']), rot_mtxs),
                        grid, sigma)
                    cg_grid = self.to_voxel(
                        torch.matmul(to_device_float(d['cg_pos']), rot_mtxs),
                        grid, sigma)
                    cg_features = torch.sum(
                        torch.from_numpy(
                            d['cg_feat'][None, :, :, None, None, None]).to(
                                self.device) * cg_grid[:, :, None, :, :, :],
                        1)

                    elems = self.transpose(
                        self.repeat(
                            self.to_tensor(
                                (d['target_type'], d['repl'],
                                 d['bonds_ndx_atom'], d['angles_ndx1_atom'],
                                 d['angles_ndx2_atom'], d['dihs_ndx_atom'],
                                 d['ljs_ndx_atom']))))
                    energy_ndx = self.repeat(
                        self.to_tensor((d['bonds_ndx'], d['angles_ndx'],
                                        d['dihs_ndx'], d['ljs_ndx'])))

                    new_coords, energies = self.predict(
                        elems, (aa_grid, cg_features), energy_ndx)

                    # Keep the lowest-energy candidate and undo its rotation.
                    best = energies.argmin()
                    best_coords = torch.matmul(
                        new_coords[best],
                        rot_mtxs_transposed[best]).detach().cpu().numpy()
                    for c, a in zip(best_coords, d['atom_seq']):
                        a.pos = d['loc_env'].rot_back(c)
            print(timer() - start)
            stats.evaluate(train=False,
                           subdir=str(self.epoch) + "_" + str(pass_ndx),
                           save_samples=True)

        # reset atom positions
        for sample in self.data.samples_val:
            sample.kick_atoms()
    finally:
        self.generator.train()
        self.critic.train()
def val(self):
    """Autoregressively generate atoms for each validation molecule.

    Run ``n_gibbs`` independent passes. In every pass, each element yielded
    by a fresh ``Mol_Rec_Generator`` is processed as follows:
    - its positions are replicated under ``val_bs`` rotations and voxelised;
    - atoms are generated one at a time, each generated grid being inserted
      wherever ``repl`` is False;
    - the resulting coordinates are rotated back, averaged over the
      rotations and assigned to ``mol.atoms``.

    After each pass the validation input samples are written to
    ``<output_dir>/samples/<name>_<pass>.gro`` and their original positions
    are restored.

    Generator and critic are in eval mode for the run and always returned
    to train mode. Original positions are also restored when an exception
    interrupts a pass.
    """
    start = timer()
    resolution = self.cfg.getint('grid', 'resolution')
    grid_length = self.cfg.getfloat('grid', 'length')
    delta_s = grid_length / resolution
    sigma_inp = self.cfg.getfloat('grid', 'sigma_inp')
    sigma_out = self.cfg.getfloat('grid', 'sigma_out')
    grid = torch.from_numpy(make_grid_np(delta_s, resolution)).to(self.device)
    val_bs = self.cfg.getint('validate', 'batchsize')
    # NOTE(review): the inverse rotations come from a second rot_mtx_batch
    # call; averaging below is only meaningful if rot_mtx_batch is
    # deterministic so that both calls describe the same rotations.
    rot_mtxs = torch.from_numpy(rot_mtx_batch(val_bs)).to(self.device).float()
    rot_mtxs_transposed = torch.from_numpy(
        rot_mtx_batch(val_bs, transpose=True)).to(self.device).float()

    # Snapshot the input coordinates; every pass starts from them.
    pos_dict = {
        atom: atom.pos
        for sample in self.data.samples_val_inp
        for atom in sample.atoms
    }

    def restore_positions():
        for atom, pos in pos_dict.items():
            atom.pos = pos

    try:
        self.generator.eval()
        self.critic.eval()
        for n in range(self.cfg.getint('validate', 'n_gibbs')):
            for d in Mol_Rec_Generator(self.data, train=False, rand_rot=False):
                with torch.no_grad():
                    mol = d['mol']
                    # One molecule matmul'd with the batch of rotation
                    # matrices -> one rotated copy per matrix (assumes
                    # rot_mtx_batch returns (val_bs, 3, 3) — TODO confirm).
                    inp_positions = torch.from_numpy(
                        np.array([d['positions']])).to(self.device).float()
                    aa_grid = self.to_voxel(
                        torch.matmul(inp_positions, rot_mtxs), grid, sigma_inp)

                    elems = self.transpose(
                        self.insert_dim(
                            self.to_tensor((d['featvec'], d['repl']))))

                    generated_atoms = []
                    for featvec, repl in zip(*elems):
                        features = torch.sum(
                            aa_grid[:, :, None, :, :, :] *
                            featvec[:, :, :, None, None, None], 1)
                        if self.z_dim != 0:
                            z = torch.empty(
                                [features.shape[0], self.z_dim],
                                dtype=torch.float32,
                                device=self.device,
                            ).normal_()
                            fake_atom = self.generator(z, features)
                        else:
                            fake_atom = self.generator(features)
                        generated_atoms.append(fake_atom)
                        # Keep existing channels where repl is True, insert
                        # the freshly generated atom everywhere else.
                        aa_grid = torch.where(repl[:, :, None, None, None],
                                              aa_grid, fake_atom)

                    coords = avg_blob(
                        torch.cat(generated_atoms, dim=1),
                        res=resolution,
                        width=grid_length,
                        sigma=sigma_out,
                        device=self.device,
                    )
                    # Undo the per-copy rotations and average over copies.
                    coords = torch.matmul(coords, rot_mtxs_transposed)
                    positions = (torch.sum(coords, 0) /
                                 val_bs).detach().cpu().numpy()
                    positions = np.dot(positions, mol.rot_mat.T)
                    for pos, atom in zip(positions, mol.atoms):
                        atom.pos = pos + mol.com

            samples_dir = self.out.output_dir / "samples"
            samples_dir.mkdir(exist_ok=True)
            for sample in self.data.samples_val_inp:
                sample.write_aa_gro_file(samples_dir /
                                         (sample.name + "_" + str(n) + ".gro"))
            restore_positions()
    finally:
        # Also restore when a pass was interrupted by an exception.
        restore_positions()
        self.generator.train()
        self.critic.train()
    print("validation took ", timer() - start, "secs")
def val(self):
    """Predict CG bead positions for the validation molecules and write them.

    For each batch of ``Mol_Generator_AA`` validation elements:
    - the intra-molecular AA positions are voxelised and weighted by their
      feature vectors;
    - if ``self.n_env_mols`` is non-zero, the inter-molecular environment is
      processed the same way and concatenated as extra channels;
    - the generator output is cut to ``self.ff_cg.n_atoms`` channels and
      converted to coordinates with ``avg_blob``;
    - the coordinates are rotated back with ``mol.rot_mat``, shifted by
      ``mol.com`` and assigned to ``mol.beads``.

    Every sample in ``self.data.samples_val_aa`` is then written to
    ``<output_dir>/samples/<name>.gro``. Generator and critic are always
    returned to train mode.
    """
    resolution = self.cfg.getint('grid', 'resolution')
    grid_length = self.cfg.getfloat('grid', 'length')
    delta_s = grid_length / resolution
    sigma_aa = self.cfg.getfloat('grid', 'sigma_aa')
    sigma_cg = self.cfg.getfloat('grid', 'sigma_cg')
    grid = torch.from_numpy(make_grid_np(delta_s, resolution)).to(self.device)

    def blob_features(positions, featvec):
        # Voxelise the positions via self.to_voxel (width sigma_aa), weight
        # each atom's grid by its feature vector and sum over atoms (axis 1).
        blobs = self.to_voxel(
            torch.from_numpy(positions).to(self.device).float(), grid,
            sigma_aa)
        weighted = (torch.from_numpy(
            featvec[:, :, :, None, None, None]).to(self.device) *
                    blobs[:, :, None, :, :, :])
        return torch.sum(weighted, 1)

    all_elems = list(Mol_Generator_AA(self.data, train=False, rand_rot=False))
    try:
        self.generator.eval()
        self.critic.eval()
        for ndx in range(0, len(all_elems), self.bs):
            batch = all_elems[ndx:ndx + self.bs]
            # Plain list: np.array over molecule objects builds an object
            # array and may try to unpack sequence-like objects.
            mols = [d['aa_mol'] for d in batch]
            with torch.no_grad():
                features = blob_features(
                    np.array([d['aa_positions_intra'] for d in batch]),
                    np.array([d['aa_intra_featvec'] for d in batch]))
                if self.n_env_mols:
                    features_inter = blob_features(
                        np.array([d['aa_positions_inter'] for d in batch]),
                        np.array([d['aa_inter_featvec'] for d in batch]))
                    features = torch.cat((features, features_inter), 1)

                if self.z_dim != 0:
                    z = torch.empty(
                        [features.shape[0], self.z_dim],
                        dtype=torch.float32,
                        device=self.device,
                    ).normal_()
                    fake_mol = self.generator(z, features)
                else:
                    fake_mol = self.generator(features)
                # Only the first n_atoms output channels are used as beads.
                fake_mol = fake_mol[:, :self.ff_cg.n_atoms]

                coords = avg_blob(
                    fake_mol,
                    res=resolution,
                    width=grid_length,
                    sigma=sigma_cg,
                    device=self.device,
                )
                for positions, mol in zip(coords, mols):
                    positions = np.dot(positions.detach().cpu().numpy(),
                                       mol.rot_mat.T)
                    for pos, bead in zip(positions, mol.beads):
                        bead.pos = pos + mol.com

        samples_dir = self.out.output_dir / "samples"
        samples_dir.mkdir(exist_ok=True)
        for sample in self.data.samples_val_aa:
            sample.write_gro_file(samples_dir / (sample.name + ".gro"))
    finally:
        self.generator.train()
        self.critic.train()
def val(self):
    """Assign each training molecule the CG candidate that best overlaps it.

    For every element produced by ``Mol_Generator`` (training split):
    - the molecule's intra-molecular AA positions are voxelised and weighted
      by their feature vectors;
    - every element's ``cg_positions_intra`` is voxelised and scored with
      ``self.overlap_loss``;
    - the candidate with the lowest loss is rotated back with
      ``mol.rot_mat``, shifted by ``mol.com`` and assigned to ``mol.beads``.

    All samples in ``self.data.samples_train_aa`` are then written to
    ``<output_dir>/samples/<name>.gro``. Generator and critic are always
    returned to train mode.

    Raises:
        ValueError: if, for some molecule, no candidate loss compares below
            infinity (for example, all losses are NaN).
    """
    resolution = self.cfg.getint('grid', 'resolution')
    grid_length = self.cfg.getfloat('grid', 'length')
    delta_s = grid_length / resolution
    sigma_aa = self.cfg.getfloat('grid', 'sigma_aa')
    sigma_cg = self.cfg.getfloat('grid', 'sigma_cg')
    grid = torch.from_numpy(make_grid_np(delta_s, resolution)).to(self.device)

    all_elems = list(Mol_Generator(self.data, train=True, rand_rot=False))
    # The candidate CG coordinates are the same for every molecule, so stack
    # them once per batch instead of once per (molecule, batch) pair.
    candidate_batches = [
        np.array([d['cg_positions_intra'] for d in all_elems[i:i + self.bs]])
        for i in range(0, len(all_elems), self.bs)
    ]
    print("jetzt gehts los")  # debug progress output ("here we go")
    try:
        self.generator.eval()
        self.critic.eval()
        print("oha")  # debug progress output
        for n, elem in enumerate(all_elems):
            mol = elem['aa_mol']
            with torch.no_grad():
                aa_positions = torch.from_numpy(
                    np.array([elem['aa_positions_intra']])).to(
                        self.device).float()
                aa_featvec = np.array([elem['aa_intra_featvec']])
                aa_blobs = self.to_voxel(aa_positions, grid, sigma_aa)
                features = torch.sum(
                    torch.from_numpy(
                        aa_featvec[:, :, :, None, None, None]).to(self.device) *
                    aa_blobs[:, :, None, :, :, :], 1)
                print(n)

                # Infinite sentinel: a fixed threshold (previously 100.0)
                # leaves the result unset, or stale from the previous
                # molecule, whenever every loss exceeds it.
                best_loss = float('inf')
                best_coords = None
                for candidates in candidate_batches:
                    target = self.to_voxel(
                        torch.from_numpy(candidates).to(self.device).float(),
                        grid, sigma_cg)
                    losses = self.overlap_loss(
                        features, target).detach().cpu().numpy()
                    i_min = losses.argmin()
                    if losses[i_min] < best_loss:
                        best_loss = losses[i_min]
                        best_coords = candidates[i_min]

            if best_coords is None:
                raise ValueError(
                    f"no finite overlap loss found for molecule {n}")
            best_coords = np.dot(best_coords, mol.rot_mat.T)
            for pos, bead in zip(best_coords, mol.beads):
                bead.pos = pos + mol.com

        samples_dir = self.out.output_dir / "samples"
        samples_dir.mkdir(exist_ok=True)
        for sample in self.data.samples_train_aa:
            sample.write_gro_file(samples_dir / (sample.name + ".gro"))
    finally:
        self.generator.train()
        self.critic.train()