def pose2pandas(pose: pyrosetta.Pose, scorefxn: pyrosetta.ScoreFunction) -> pd.DataFrame:
    """
    Score the pose and return its per-residue energies as a pandas dataframe.

    The energies of the pose are cleared and the pose is rescored with ``scorefxn``
    before the table is built, so the pose is modified as a side effect.

    :param pose: pose to score
    :param scorefxn: score function, called on the pose
    :return: dataframe built from the pose's ``residue_total_energies_array()``,
        with an extra ``residue`` column holding the residue's one-letter name
        followed by the output of ``pdb_info().pose2pdb`` for that residue
    """
    # rescore from a clean slate so the table reflects ``scorefxn``
    pose.energies().clear_energies()
    scorefxn(pose)
    table = pd.DataFrame(pose.energies().residue_total_energies_array())
    pdb_info = pose.pdb_info()

    def label(row_index):
        # the dataframe rows count from zero, the pose residues from one
        pose_index = row_index + 1
        return pose.residue(pose_index).name1() + pdb_info.pose2pdb(pose_index)

    table['residue'] = table.index.to_series().apply(label)
    return table
def add_bfactor_from_score(pose: pyrosetta.Pose):
    """
    Adds the bfactors from total_score.

    The per-residue ``total_score`` values are shifted so the lowest is 0 and rescaled
    so the highest is 100, then written as the B-factor of every atom of the residue
    in the ``pdb_info`` of the pose (modified in place).
    Residues whose score is NaN or cannot be read get a B-factor of 0.

    Snippet for testing in Jupyter

    >>> import nglview as nv
    >>> view = nv.show_rosetta(pose)
    >>> # view = nv.show_file('test.cif')
    >>> view.clear_representations()
    >>> view.add_tube(radiusType="bfactor", color="bfactor", radiusScale=0.10, colorScale="RdYlBu")
    >>> view

    ``replace_res_remap_bfactors`` may have been a cleaner strategy. This was quicker to write.

    If this fails, it may be because the pose was not scored first.

    :param pose: pose whose B-factors are overwritten
    :return: None
    :raises ValueError: if the ``pdb_info`` of the pose is flagged as obsolete
    """
    pdb_info = pose.pdb_info()
    if pdb_info.obsolete():
        raise ValueError(
            'Pose pdb_info is flagged as obsolete (change `pose.pdb_info().obsolete(False)`)'
        )
    # scores
    energies = pose.energies()
    total_score = pyrosetta.rosetta.core.scoring.ScoreType.total_score
    n_residues = pose.total_residue()

    def get_res_score(res):
        # best effort: a residue whose energies cannot be read is marked NaN
        # instead of aborting. ``Exception`` (not a bare ``except``) so that
        # KeyboardInterrupt and SystemExit still propagate.
        try:
            return energies.residue_total_energies(res)[total_score]
        except Exception:
            return float('nan')

    # the array goes from zero (nan) to n_residues, so it can be indexed by pose residue index
    total_scores = np.array(
        [float('nan')] + [get_res_score(res) for res in range(1, n_residues + 1)])
    # remember which entries had no score before the scaling introduces any further NaN
    mask = np.isnan(total_scores)
    total_scores -= np.nanmin(total_scores)
    # NOTE: if all readable scores are equal the maximum is zero here, the factor is infinite
    # and the products are NaN; ``nan_to_num`` below turns those into 100.
    total_scores *= 100 / np.nanmax(total_scores)
    total_scores = np.nan_to_num(total_scores, nan=100)
    total_scores[mask] = 0.
    # add to pose: every atom of a residue gets the scaled score of that residue
    for res in range(1, n_residues + 1):
        for atom in range(1, pose.residue(res).natoms() + 1):
            pdb_info.bfactor(res, atom, total_scores[res])
def get_scoredict(self, pose: pyrosetta.Pose) -> Dict[str, float]:
    """
    Collect the total energies of a pose into a dictionary.

    The pose is not scored here: only ``pose.energies().total_energies_array()`` is read.

    :param pose: pose whose total energies are read
    :return: each field name of the array's dtype mapped to the matching value of its first record
    """
    totals = pose.energies().total_energies_array()
    field_names = totals.dtype.fields.keys()
    first_record = totals.tolist()[0]
    return {name: value for name, value in zip(field_names, first_record)}