def cartesian_gaussian():
    """Generate a 3D Cartesian Gaussian with random mean and covariance."""
    rotation = stats.special_ortho_group(3).rvs()
    eigenvalues = stats.gamma(2).rvs(3)
    cov = rotation @ np.diag(eigenvalues) @ rotation.T
    mean = stats.norm.rvs(size=3)
    return stats.multivariate_normal(mean, cov)
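# Minimal usage sketch for the function above (assumes the same `from scipy import stats`
# and `import numpy as np` imports). Because R diag(eigenvalues) R^T is symmetric positive
# definite, the frozen distribution can be sampled and evaluated directly.
dist = cartesian_gaussian()
samples = dist.rvs(size=1000)        # (1000, 3) array of 3D points
log_density = dist.logpdf(samples)   # log-density at each sample point
print(samples.shape, log_density.shape)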
def test_frozen_matrix(self):
    dim = 7
    frozen = special_ortho_group(dim)
    rvs1 = frozen.rvs(random_state=1234)
    rvs2 = special_ortho_group.rvs(dim, random_state=1234)
    assert_equal(rvs1, rvs2)
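# Standalone sketch of the properties this test relies on (outside the test class):
# draws from special_ortho_group are rotation matrices, i.e. orthogonal with determinant +1,
# and a frozen distribution with a fixed seed reproduces the unfrozen call exactly.
import numpy as np
from scipy.stats import special_ortho_group

R = special_ortho_group(7).rvs(random_state=1234)
assert np.allclose(R @ R.T, np.eye(7), atol=1e-10)
assert np.isclose(np.linalg.det(R), 1.0)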
def build_moment_index(self, name, mesh, grid_size):
    print(f'Building HuIndex for {name}')
    n_samples = 1000
    for i in range(n_samples):
        print(f'{i}/{n_samples}')
        R = special_ortho_group(3).rvs()
        mask, offset, vt = self.mesh_to_mask(mesh, R, grid_size)
        hu = cv2.HuMoments(cv2.moments(np.float32(mask)))
        lower, upper = hu - abs(hu * 0.3), hu + abs(hu * 0.3)
        self.Index.insert(i, np.concatenate((lower, upper), axis=0), obj=(R, offset, vt))
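# Illustrative sketch of the Hu-moment computation above, using only cv2 and numpy.
# The toy rectangle stands in for a rasterized mesh mask; the +/-30% bounds form the
# bounding box that is inserted into the spatial index.
import cv2
import numpy as np

mask = np.zeros((64, 64), dtype=np.float32)
mask[20:40, 10:50] = 1.0
hu = cv2.HuMoments(cv2.moments(mask))                    # (7, 1) rotation-invariant moments
lower, upper = hu - abs(hu * 0.3), hu + abs(hu * 0.3)    # +/-30% tolerance box
print(np.concatenate((lower, upper), axis=0).ravel())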
def pair_model(model, FLAGS, node_embed):
    # Generate a dataset where two atoms are very close to each other
    # and everything else is very far away.

    # Indices for atoms
    atom_names = ["X", "C", "N", "O", "S"]
    residue_names = [
        "ALA", "ARG", "ASN", "ASP", "CYS", "GLU", "GLN", "GLY", "HIS", "ILE",
        "LEU", "LYS", "MET", "PHE", "PRO", "SER", "THR", "TRP", "TYR", "VAL",
    ]
    energies_output_dict = {}

    def make_key(n_rotations, residue_name1, residue_name2, atom_name1, atom_name2):
        return f"{n_rotations}_{residue_name1}_{residue_name2}_{atom_name1}_{atom_name2}"

    # Save a copy of the node embedding
    node_embed = node_embed[0]
    node_embed_orig = node_embed.clone()

    # Try different combinations
    for n_rotations in [5]:
        # Rotations
        so3 = special_ortho_group(3)
        rot_matrix_neg = so3.rvs(n_rotations)  # number of random rotations to average

        residue_names_proc = ["ALA", "TYR", "LEU"]
        atom_names_proc = ["C", "N", "O"]
        for residue_name1, residue_name2 in itertools.product(residue_names_proc, repeat=2):
            for atom_name1, atom_name2 in itertools.product(atom_names_proc, repeat=2):
                eps = []
                energies = []
                residue_index1 = residue_names.index(residue_name1)
                residue_index2 = residue_names.index(residue_name2)
                atom_index1 = atom_names.index(atom_name1)
                atom_index2 = atom_names.index(atom_name2)

                for i in np.linspace(0.1, 1.0, 100):
                    node_embed = node_embed_orig.clone()
                    node_embed[-2, -3:] = torch.Tensor([1.0, 0.5, 0.5])
                    node_embed[-1, -3:] = torch.Tensor([1.0 + i, 0.5, 0.5])
                    node_embed[-1, 0] = residue_index1
                    node_embed[-2, 0] = residue_index2
                    node_embed[-1, 1] = atom_index1
                    node_embed[-2, 1] = atom_index2
                    node_embed[-1, 2] = 6  # res_counter
                    node_embed[-2, 2] = 6  # res_counter

                    node_embed = np.tile(node_embed[None, :, :], (n_rotations, 1, 1))
                    node_embed[:, :, -3:] = np.matmul(node_embed[:, :, -3:], rot_matrix_neg)

                    node_embed_feed = torch.Tensor(node_embed).cuda()
                    node_embed_feed[:, :, -3:] = node_embed_feed[:, :, -3:] - node_embed_feed[:, :, -3:].mean(
                        dim=1, keepdim=True)

                    energy = model.forward(node_embed_feed)
                    # energy = energy.mean()

                    eps.append(i * 10)
                    energies.append(energy.item())

                key = make_key(n_rotations, residue_name1, residue_name2, atom_name1, atom_name2)
                energies_output_dict[key] = (eps, energies)

                # Optionally make plots here -- potentially add conditions to avoid making too many
                plt.plot(eps, energies)
                plt.xlabel("Atom Distance")
                plt.ylabel("Energy")
                plt.title(
                    f"{n_rotations} rots: {atom_name1}, {atom_name2} distance in {residue_name1}/{residue_name2}"
                )
                plt.savefig(
                    f"distance_plots/{n_rotations}_{atom_name1}_{atom_name2}_in_{residue_name1}_{residue_name2}_distance.png"
                )
                plt.clf()

    # Back outside the loops: dump all computed energies
    output_path = osp.join(FLAGS.outdir, "atom_distances.p")
    pickle.dump(energies_output_dict, open(output_path, "wb"))
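# Sketch of the rotation-averaging pattern used above, in isolation: a batch of random
# SO(3) matrices is applied to the coordinate channels with one broadcasted matmul.
# The sizes here are assumptions chosen only for illustration.
import numpy as np
from scipy.stats import special_ortho_group

n_rotations, n_atoms = 5, 16
coords = np.random.randn(n_atoms, 3)
rots = special_ortho_group(3).rvs(n_rotations)                          # (5, 3, 3)
rotated = np.matmul(np.tile(coords[None], (n_rotations, 1, 1)), rots)   # (5, 16, 3)
print(rotated.shape)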
def new_model(model, FLAGS, node_embed):
    BATCH_SIZE = 120
    pdb_name = FLAGS.pdb_name  # e.g. '6mdw.0'
    pickle_file = f"/private/home/yilundu/dataset/mmcif/mmCIF/{pdb_name[1:3]}/{pdb_name}.p"
    (node_embed, ) = pickle.load(open(pickle_file, "rb"))

    par, child, pos, pos_exist, res, chis_valid = parse_dense_format(node_embed)
    angles = compute_dihedral(par, child, pos, pos_exist)
    chis_target_initial = angles[:, 4:8].copy()  # dihedrals for backbone (:4); dihedrals for sidechain (4:8)

    NUM_RES = len(res)
    all_energies = np.empty((NUM_RES, 4, 360))  # 4 is the number of possible chi angles

    surface_core_type = []
    for idx in range(NUM_RES):
        dist = np.sqrt(np.square(pos[idx:idx + 1, 2] - pos[:, 2]).sum(axis=1))
        neighbors = (dist < 10).sum()
        if neighbors >= 24:
            surface_core_type.append("core")
        elif neighbors <= 16:
            surface_core_type.append("surface")
        else:
            surface_core_type.append("unlabeled")

    for idx in tqdm(range(NUM_RES)):
        for chi_num in range(4):
            if not chis_valid[idx, chi_num]:
                continue
            # init_angle = chis_target[idx, chi_num]

            for angle_deltas in batch(range(-180, 180, 3), BATCH_SIZE):
                pre_rot_node_embed_short = []
                for angle_delta in angle_deltas:
                    chis_target = chis_target_initial.copy()  # make a local copy
                    # Modify the angle by angle_delta, then rotate to chis_target
                    chis_target[idx, chi_num] += angle_delta  # set the specific chi angle to the sampled value

                    # pos_new is (n residues x 20 atoms x 3 (xyz))
                    pos_new = rotate_dihedral_fast(angles, par, child, pos, pos_exist,
                                                   chis_target, chis_valid, idx)
                    node_neg_embed = reencode_dense_format(node_embed, pos_new, pos_exist)

                    # Sort the atoms by how far away they are.
                    # The sort key is the first atom on the sidechain.
                    pos_chosen = pos_new[idx, 4]
                    close_idx = np.argsort(
                        np.square(node_neg_embed[:, -3:] - pos_chosen).sum(axis=1))

                    # Grab the 64 closest atoms
                    node_embed_short = node_neg_embed[close_idx[:FLAGS.max_size]].copy()

                    # Normalize node_embed so the mean x, y, z coordinate is 0
                    node_embed_short[:, -3:] = node_embed_short[:, -3:] - np.mean(
                        node_embed_short[:, -3:], axis=0)
                    node_embed_short = torch.from_numpy(node_embed_short).float().cuda()
                    pre_rot_node_embed_short.append(node_embed_short.unsqueeze(0))

                pre_rot_node_embed_short = torch.stack(pre_rot_node_embed_short)

                # Now rotate all elements
                n_rotations = 100
                so3 = special_ortho_group(3)
                rot_matrix = so3.rvs(n_rotations)  # n x 3 x 3
                node_embed_short = pre_rot_node_embed_short.repeat(1, n_rotations, 1, 1)
                rot_matrix = torch.from_numpy(rot_matrix).float().cuda()
                node_embed_short[:, :, :, -3:] = torch.matmul(
                    node_embed_short[:, :, :, -3:], rot_matrix)  # (batch_size, n_rotations, 64, 20)

                # Compute the energies for the n_rotations * batch_size copies of this window of 64 atoms.
                # Merge the first two dimensions into a batch, then pull them apart afterwards.
                node_embed_short = node_embed_short.reshape(
                    node_embed_short.shape[0] * node_embed_short.shape[1],
                    *node_embed_short.shape[2:],
                )
                energies = model.forward(node_embed_short)  # (BATCH_SIZE * n_rotations, 1)

                # Un-merge the batch dimension, then average across the n_rotations
                # while keeping batch entries separate
                energies = energies.reshape(BATCH_SIZE, -1)  # (BATCH_SIZE, n_rotations)
                energies = energies.mean(1)                  # (BATCH_SIZE,)

                # Save the result
                all_energies[idx, chi_num, angle_deltas] = energies.cpu().numpy()

    # Can use these for processing later.
    avg_chi_angle_energy = (
        all_energies * chis_valid[:NUM_RES, :4, None]).sum(0) / np.expand_dims(
            chis_valid[:NUM_RES, :4].sum(0), 1)  # normalize by how many times each chi angle occurs

    output = {
        "all_energies": all_energies,
        "chis_valid": chis_valid,
        "chis_target_initial": chis_target_initial,
        "avg_chi_angle_energy": avg_chi_angle_energy,  # make four plots from this (4, 360)
        "res": res,
        "surface_core_type": surface_core_type,
    }

    # Dump the output
    output_path = osp.join(FLAGS.outdir, f"{pdb_name}_rot_energies.p")
    if not osp.exists(FLAGS.outdir):
        os.makedirs(FLAGS.outdir)
    pickle.dump(output, open(output_path, "wb"))
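# Hypothetical plotting sketch for the "make four plots from this (4, 360)" note above:
# load the pickle written by new_model and draw one average-energy curve per chi angle.
# The file name is an assumption (it depends on FLAGS.outdir and FLAGS.pdb_name).
import pickle
import matplotlib.pyplot as plt

output = pickle.load(open("6mdw.0_rot_energies.p", "rb"))
avg = output["avg_chi_angle_energy"]   # shape (4, 360)
for chi_num in range(4):
    plt.plot(avg[chi_num])
    plt.xlabel("Angle bin")
    plt.ylabel("Average energy")
    plt.title(f"Chi {chi_num + 1}")
    plt.savefig(f"chi_{chi_num + 1}_energy.png")
    plt.clf()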
def make_tsne(model, FLAGS, node_embed):
    """Grab representations for each of the residues in a pdb."""
    pdb_name = FLAGS.pdb_name
    pickle_file = MMCIF_PATH + f"/mmCIF/{pdb_name[1:3]}/{pdb_name}.p"
    (node_embed, ) = pickle.load(open(pickle_file, "rb"))
    par, child, pos, pos_exist, res, chis_valid = parse_dense_format(node_embed)
    angles = compute_dihedral(par, child, pos, pos_exist)

    all_hiddens = []
    all_energies = []

    n_rotations = 2
    so3 = special_ortho_group(3)
    rot_matrix = so3.rvs(n_rotations)
    rot_matrix = torch.from_numpy(rot_matrix).float().cuda()

    for idx in range(len(res)):
        # Sort the atoms by how far away they are.
        # The sort key is the first atom on the sidechain.
        pos_chosen = pos[idx, 4]
        close_idx = np.argsort(np.square(node_embed[:, -3:] - pos_chosen).sum(axis=1))

        # Grab the 64 closest atoms
        node_embed_short = node_embed[close_idx[:FLAGS.max_size]].copy()

        # Normalize node_embed so the mean x, y, z coordinate is 0
        node_embed_short[:, -3:] = node_embed_short[:, -3:] - np.mean(
            node_embed_short[:, -3:], axis=0)
        node_embed_short = torch.from_numpy(node_embed_short).float().cuda()
        node_embed_short = node_embed_short[None, :, :].repeat(n_rotations, 1, 1)
        node_embed_short[:, :, -3:] = torch.matmul(node_embed_short[:, :, -3:], rot_matrix)

        # Compute the energies for the n_rotations copies of this window of 64 atoms.
        # node_embed_short = node_embed_short.reshape(node_embed_short.shape[0] * node_embed_short.shape[1], *node_embed_short.shape[2:])
        energies, hidden = model.forward(node_embed_short, return_hidden=True)

        # all_hiddens.append(hidden.mean(0))  # mean over the rotations
        all_hiddens.append(hidden[0])  # take the first rotation
        all_energies.append(energies[0])

    surface_core_type = []
    for idx in range(len(res)):
        # Core: >16 c-beta neighbors within 10A of its own c-beta (or c-alpha for gly).
        hacked_pos = np.copy(pos)
        swap_hacked_pos = np.swapaxes(hacked_pos, 0, 1)  # (20, 59, 3)
        idxs_to_change = swap_hacked_pos[4] == [0, 0, 0]  # (59, 3)
        swap_hacked_pos[4][idxs_to_change] = swap_hacked_pos[3][idxs_to_change]
        hacked_pos_final = np.swapaxes(swap_hacked_pos, 0, 1)

        dist = np.sqrt(
            np.square(hacked_pos_final[idx:idx + 1, 4] - hacked_pos_final[:, 4]).sum(axis=1))
        neighbors = (dist < 10).sum()
        if neighbors >= 24:
            surface_core_type.append("core")
        elif neighbors <= 16:
            surface_core_type.append("surface")
        else:
            surface_core_type.append("unlabeled")

    output = {
        "res": res,
        "surface_core_type": surface_core_type,
        "all_hiddens": torch.stack(all_hiddens).cpu().numpy(),
        "all_energies": torch.stack(all_energies).cpu().numpy(),
    }

    # Dump the output
    output_path = osp.join(FLAGS.outdir, f"{pdb_name}_representations.p")
    if not osp.exists(FLAGS.outdir):
        os.makedirs(FLAGS.outdir)
    pickle.dump(output, open(output_path, "wb"))
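# Hypothetical follow-up to make_tsne (not part of the function above): the pickled hidden
# vectors can be projected to 2D with scikit-learn's TSNE for visualization. The file name
# is an assumption; it depends on FLAGS.outdir and FLAGS.pdb_name.
import pickle
from sklearn.manifold import TSNE

output = pickle.load(open("6mdw.0_representations.p", "rb"))
embedding = TSNE(n_components=2).fit_transform(output["all_hiddens"])
print(embedding.shape)   # (n_residues, 2)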
def rotamer_trials(model, FLAGS, test_dataset):
    test_files = test_dataset.files
    random.shuffle(test_files)
    db = load_rotamor_library()
    so3 = special_ortho_group(3)
    node_embed_evals = []
    nminibatch = 4

    if FLAGS.ensemble > 1:
        models = model

    # The three different sampling methods are weighted_gauss, gmm, and rosetta
    rotamer_scores_total = []
    surface_scores_total = []
    buried_scores_total = []
    amino_recovery_total = {}
    for k, v in kvs.items():
        amino_recovery_total[k.lower()] = []

    counter = 0
    rotations = FLAGS.rotations

    for test_file in tqdm(test_files):
        (node_embed, ) = pickle.load(open(test_file, "rb"))
        node_embed_original = node_embed
        par, child, pos, pos_exist, res, chis_valid = parse_dense_format(node_embed)
        angles = compute_dihedral(par, child, pos, pos_exist)

        amino_recovery = {}
        for k, v in kvs.items():
            amino_recovery[k.lower()] = []

        if node_embed is None:
            continue

        rotamer_scores = []
        surface_scores = []
        buried_scores = []
        types = []
        gt_chis = []
        node_embed_evals = []
        neg_chis = []
        valid_chi_idxs = []
        res_names = []
        neg_sample = FLAGS.neg_sample

        n_amino = pos.shape[0]
        amino_recovery_curr = {}

        for idx in range(1, n_amino - 1):
            res_name = res[idx]

            if res_name == "gly" or res_name == "ala":
                continue

            res_names.append(res_name)
            gt_chis.append(angles[idx, 4:8])
            valid_chi_idxs.append(chis_valid[idx, :4])

            hacked_pos = np.copy(pos)
            swap_hacked_pos = np.swapaxes(hacked_pos, 0, 1)  # (20, 59, 3)
            idxs_to_change = swap_hacked_pos[4] == [0, 0, 0]  # (59, 3)
            swap_hacked_pos[4][idxs_to_change] = swap_hacked_pos[3][idxs_to_change]
            hacked_pos_final = np.swapaxes(swap_hacked_pos, 0, 1)

            neighbors = np.linalg.norm(
                pos[idx:idx + 1, 4] - hacked_pos_final[:, 4], axis=1) < 10
            neighbors = neighbors.astype(np.int32).sum()

            if neighbors >= 24:
                types.append("buried")
            elif neighbors < 16:
                types.append("surface")
            else:
                types.append("neutral")

            if neighbors >= 24:
                tresh = 0.98
            else:
                tresh = 0.95

            if FLAGS.sample_mode == "weighted_gauss":
                chis_list = interpolated_sample_normal(
                    db, angles[idx, 1], angles[idx, 2], res[idx], neg_sample, uniform=False)
            elif FLAGS.sample_mode == "gmm":
                chis_list = mixture_sample_normal(
                    db, angles[idx, 1], angles[idx, 2], res[idx], neg_sample, uniform=False)
            elif FLAGS.sample_mode == "rosetta":
                chis_list = exhaustive_sample(
                    db, angles[idx, 1], angles[idx, 2], res[idx], tresh=tresh)

            neg_chis.append(chis_list)
            node_neg_embeds = []
            length_chis = len(chis_list)

            for i in range(neg_sample):
                chis_target = angles[:, 4:8].copy()

                if i >= len(chis_list):
                    node_neg_embed_copy = node_neg_embed.copy()
                    node_neg_embeds.append(node_neg_embeds[i % length_chis])
                    neg_chis[-1].append(chis_list[i % length_chis])
                    continue

                chis = chis_list[i]
                chis_target[idx] = (
                    chis * chis_valid[idx, :4]
                    + (1 - chis_valid[idx, :4]) * chis_target[idx])
                pos_new = rotate_dihedral_fast(angles, par, child, pos, pos_exist,
                                               chis_target, chis_valid, idx)
                node_neg_embed = reencode_dense_format(node_embed, pos_new, pos_exist)
                node_neg_embeds.append(node_neg_embed)

            node_neg_embeds = np.array(node_neg_embeds)
            dist = np.square(node_neg_embeds[:, :, -3:] - pos[idx:idx + 1, 4:5, :]).sum(axis=2)
            close_idx = np.argsort(dist)

            node_neg_embeds = np.take_along_axis(node_neg_embeds, close_idx[:, :64, None], axis=1)
            node_neg_embeds[:, :, -3:] = node_neg_embeds[:, :, -3:] / 10.0
            node_neg_embeds[:, :, -3:] = node_neg_embeds[:, :, -3:] - np.mean(
                node_neg_embeds[:, :, -3:], axis=1, keepdims=True)

            node_embed_evals.append(node_neg_embeds)

            if len(node_embed_evals) == nminibatch or idx == (n_amino - 2):
                n_entries = len(node_embed_evals)
                node_embed_evals = np.concatenate(node_embed_evals)
                s = node_embed_evals.shape

                # Sample `rotations` random rotations per minibatch
                node_embed_evals = np.tile(node_embed_evals[:, None, :, :], (1, rotations, 1, 1))
                rot_matrix = so3.rvs(rotations)
                if rotations == 1:
                    rot_matrix = rot_matrix[None, :, :]

                node_embed_evals[:, :, :, -3:] = np.matmul(
                    node_embed_evals[:, :, :, -3:], rot_matrix[None, :, :, :])
                node_embed_evals = node_embed_evals.reshape((-1, *s[1:]))
                node_embed_feed = torch.from_numpy(node_embed_evals).float().cuda()

                with torch.no_grad():
                    energy = 0
                    if FLAGS.ensemble > 1:
                        for model in models:
                            energy_tmp = model.forward(node_embed_feed)
                            energy = energy + energy_tmp
                    else:
                        energy = model.forward(node_embed_feed)

                energy = energy.view(n_entries, -1, rotations).mean(dim=2)
                select_idx = torch.argmin(energy, dim=1).cpu().numpy()

                for i in range(n_entries):
                    select_idx_i = select_idx[i]
                    valid_chi_idx = valid_chi_idxs[i]
                    rotamer_score, _ = compute_rotamer_score_planar(
                        gt_chis[i], neg_chis[i][select_idx_i], valid_chi_idx[:4], res_names[i])
                    rotamer_scores.append(rotamer_score)

                    amino_recovery[str(res_names[i])] = (
                        amino_recovery[str(res_names[i])] + [rotamer_score])

                    if types[i] == "buried":
                        buried_scores.append(rotamer_score)
                    elif types[i] == "surface":
                        surface_scores.append(rotamer_score)

                gt_chis = []
                node_embed_evals = []
                neg_chis = []
                valid_chi_idxs = []
                res_names = []
                types = []

        counter += 1
        rotamer_scores_total.append(np.mean(rotamer_scores))

        if len(buried_scores) > 0:
            buried_scores_total.append(np.mean(buried_scores))

        surface_scores_total.append(np.mean(surface_scores))

        for k, v in amino_recovery.items():
            if len(v) > 0:
                amino_recovery_total[k] = amino_recovery_total[k] + [np.mean(v)]

        print(
            "Obtained a rotamer recovery score of ",
            np.mean(rotamer_scores_total),
            np.std(rotamer_scores_total) / len(rotamer_scores_total) ** 0.5,
        )
        print(
            "Obtained a buried recovery score of ",
            np.mean(buried_scores_total),
            np.std(buried_scores_total) / len(buried_scores_total) ** 0.5,
        )
        print(
            "Obtained a surface recovery score of ",
            np.mean(surface_scores_total),
            np.std(surface_scores_total) / len(surface_scores_total) ** 0.5,
        )
        for k, v in amino_recovery_total.items():
            print(
                "per amino acid recovery of {} score of ".format(k),
                np.mean(v),
                np.std(v) / len(v) ** 0.5,
            )
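# Sketch of the rotation-marginalized scoring used above, with toy tensors: energies over
# (entries x candidates x rotations) are averaged over the rotations, then the lowest-energy
# candidate is selected per entry. The sizes are assumptions for illustration only.
import torch

n_entries, n_candidates, n_rotations = 4, 150, 10
energy = torch.randn(n_entries * n_candidates * n_rotations, 1)
energy = energy.view(n_entries, -1, n_rotations).mean(dim=2)   # (4, 150)
select_idx = torch.argmin(energy, dim=1)                       # best candidate per entry
print(select_idx.shape)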
def __init__(
    self,
    FLAGS,
    mmcif_path="./mmcif",
    split="train",
    rank_idx=0,
    world_size=1,
    uniform=True,
    weighted_gauss=False,
    gmm=False,
    chi_mean=False,
    valid=False,
):
    files = []
    dirs = os.listdir(osp.join(mmcif_path, "mmCIF"))

    self.split = split
    self.so3 = special_ortho_group(3)
    self.chi_mean = chi_mean
    self.weighted_gauss = weighted_gauss
    self.gmm = gmm
    self.uniform = uniform

    # Filter out proteins in the test dataset
    for d in tqdm(dirs):
        directory = osp.join(mmcif_path, "mmCIF", d)
        d_files = os.listdir(directory)
        files_tmp = [osp.join(directory, d_file) for d_file in d_files if ".p" in d_file]

        for f in files_tmp:
            name = f.split("/")[-1]
            name = name.split(".")[0]

            if name in test_rotamers and self.split == "test":
                files.append(f)
            elif name not in test_rotamers and self.split in ["train", "val"]:
                files.append(f)

    self.files = files

    if split in ["train", "val"]:
        duplicate_seqs = set()
        # Remove proteins in the train dataset that are too similar to the test dataset
        with open(osp.join(mmcif_path, "duplicate_sequences.txt"), "r") as f:
            for line in f:
                duplicate_seqs.add(line.strip())

        fids = set()

        # Remove low-resolution proteins
        with open(
            osp.join(mmcif_path, "cullpdb_pc90_res1.8_R0.25_d190807_chains14857"), "r"
        ) as f:
            i = 0
            for line in f:
                if i != 0:
                    fid = line.split()[0]

                    if fid not in duplicate_seqs:
                        fids.add(fid)

                i += 1

        files_new = []
        alphabet = []
        for letter in range(65, 91):
            alphabet.append(chr(letter))

        for f in files:
            tup = (f.split("/")[-1]).split(".")

            if int(tup[1]) >= len(alphabet):
                continue

            seq_id = tup[0].upper() + alphabet[int(tup[1])]

            if seq_id in fids:
                files_new.append(f)

        self.files = files_new
    elif split == "test":
        fids = set()

        # Remove low-resolution proteins
        with open(
            osp.join(mmcif_path, "cullpdb_pc90_res1.8_R0.25_d190807_chains14857"), "r"
        ) as f:
            i = 0
            for line in f:
                if i != 0:
                    fid = line.split()[0]
                    fids.add(fid)

                i += 1

        files_new = []
        alphabet = []
        for letter in range(65, 91):
            alphabet.append(chr(letter))

        for f in files:
            tup = (f.split("/")[-1]).split(".")

            if int(tup[1]) >= len(alphabet):
                continue

            seq_id = tup[0].upper() + alphabet[int(tup[1])]

            if seq_id in fids:
                files_new.append(f)

        self.files = files_new

    chunksize = len(self.files) // world_size
    n = len(self.files)

    # Set up a validation dataset
    if split == "train":
        n = self.files[int(0.95 * n):]
    elif split == "val":
        n = self.files[:int(0.95 * n)]

    self.FLAGS = FLAGS
    self.db = load_rotamor_library()
    print(f"Loaded {len(self.files)} files for {split} dataset split")
    self.split = split
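# Toy illustration of the chain-ID mapping used above: "<pdb>.<chain index>.p" file names
# are matched against culled-PDB entries of the form "<PDBID><chain letter>". The example
# file name is hypothetical.
alphabet = [chr(letter) for letter in range(65, 91)]   # 'A'..'Z'
fname = "6mdw.0.p"
tup = fname.split(".")
seq_id = tup[0].upper() + alphabet[int(tup[1])]        # -> "6MDWA"
print(seq_id)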