def test_silent_io(self):
    """Poses round-trip through silent files with sequence and coordinates intact.

    A single pose and then a list of poses are written with ``io.to_silent``
    into a temporary directory and read back with ``io.poses_from_silent``.
    The number of poses, each sequence, and the all-atom RMSD (to the three
    decimal places a silent file keeps) are checked.
    """
    with tempfile.TemporaryDirectory() as workdir:
        # Tests if a single pose is written to a silent file and returned
        # unchanged.
        test_packed_pose = io.pose_from_sequence("TESTTESTTEST")
        test_pose = packed_pose.to_pose(test_packed_pose)
        tmp_file = os.path.join(workdir, "temp.silent")
        io.to_silent(test_packed_pose, tmp_file)
        returned_poses = [
            packed_pose.to_pose(p) for p in io.poses_from_silent(tmp_file)
        ]
        # Check the count first so a missing pose fails with a clear message
        # instead of an IndexError, and an extra pose is not ignored.
        self.assertEqual(len(returned_poses), 1,
                         msg="Single pose count recovery failed.")
        # Tests that the amino acid sequence does not change after being
        # written to a silent file and read back in. Cannot just assert that
        # the two poses are equal, because writing them to the silent file
        # adds score lines.
        self.assertEqual(test_pose.sequence(), returned_poses[0].sequence(),
                         msg="Single sequence recovery failed.")
        # Tests that the positions of atoms are almost identical after being
        # written to a silent file and read back in. The RMSD will not be
        # exactly zero because the silent file truncates xyz coordinates to
        # the third decimal place.
        self.assertAlmostEqual(
            0.,
            pyrosetta.rosetta.core.scoring.all_atom_rmsd(
                test_pose, returned_poses[0]),
            places=3,
            msg="Single position recovery failed.")

        # Test if a list of poses can be written to a silent file and
        # returned unchanged.
        test_packed_poses = [
            io.pose_from_sequence("TEST" * i) for i in range(1, 5)
        ]
        test_poses = [packed_pose.to_pose(p) for p in test_packed_poses]
        tmp_file = os.path.join(workdir, "temp_list.silent")
        io.to_silent(test_packed_poses, tmp_file)
        returned_poses = [
            packed_pose.to_pose(p) for p in io.poses_from_silent(tmp_file)
        ]
        # zip() stops at the shorter input, so compare the counts explicitly;
        # otherwise dropped or extra poses would go unnoticed.
        self.assertEqual(len(returned_poses), len(test_poses),
                         msg="List pose count recovery failed.")
        for expected, returned in zip(test_poses, returned_poses):
            self.assertEqual(expected.sequence(), returned.sequence(),
                             msg="List sequence recovery failed.")
            self.assertAlmostEqual(
                0.,
                pyrosetta.rosetta.core.scoring.all_atom_rmsd(
                    expected, returned),
                places=3,
                msg="List position recovery failed.")
def test_rosetta_scripts(self):
    """A repack-only RosettaScripts task submitted to the dask client keeps the sequence.

    The protocol restricts packing to repacking (``RestrictToRepacking``), so
    the pose returned from ``self.client`` must carry the same sequence as the
    pose that was submitted.
    """
    repack_xml = """
    <ROSETTASCRIPTS>
        <TASKOPERATIONS>
            <RestrictToRepacking name="repack"/>
        </TASKOPERATIONS>
        <MOVERS>
            <PackRotamersMover name="pack" task_operations="repack"/>
        </MOVERS>
        <PROTOCOLS>
            <Add mover="pack"/>
        </PROTOCOLS>
    </ROSETTASCRIPTS>
    """
    input_pose = io.pose_from_sequence("TEST")
    repack_task = rosetta_scripts.SingleoutputRosettaScriptsTask(repack_xml)

    logging.info("dask client: %s", self.client)
    future = self.client.submit(repack_task, input_pose)
    repacked = future.result()

    returned_sequence = packed_pose.to_pose(repacked).sequence()
    self.assertEqual(returned_sequence,
                     packed_pose.to_pose(input_pose).sequence())
def test_concurrent_score(self):
    """ScorePoseTask can be called concurrently from several threads.

    ``Executor.map`` returns a lazy iterator. An exception raised in a worker
    thread is only re-raised when the corresponding result is retrieved, so
    the results are materialized with ``list`` to make failures visible.
    """
    test_pose = io.pose_from_sequence("TEST" * 10)
    import concurrent.futures
    with concurrent.futures.ThreadPoolExecutor(max_workers=3) as p:
        results = list(p.map(score.ScorePoseTask(), [test_pose] * 3))

    self.assertEqual(len(results), 3)
    # The same input pose should score the same regardless of which thread
    # scored it.
    reference_score = results[0].scores["total_score"]
    for result in results:
        self.assertAlmostEqual(result.scores["total_score"], reference_score)
def test_gil_score(self):
    """Scoring a pose should not starve a concurrently running HeartBeat.

    A pose is scored inside a ``HeartBeat`` context. Afterwards the recorded
    beat intervals must be within ``rtol=1`` of the requested interval, and
    more than four beats must have been recorded. HeartBeat is defined
    elsewhere; it presumably ticks from a background Python thread, so
    regular beats would indicate the GIL is released during scoring
    (confirm against HeartBeat).
    """
    beat_period = 10e-3
    with HeartBeat(beat_period) as heartbeat:
        pose = io.pose_from_sequence("TESTTESTTEST" * 10)
        scorer = score.ScorePoseTask()
        scorer(pose)

    numpy.testing.assert_allclose(
        heartbeat.beat_intervals, heartbeat.interval, rtol=1)
    self.assertGreater(len(heartbeat.beats), 4)
def setUp(self, local_dir=workdir):
    """Write PDB fixtures ``tmp_1.pdb`` .. ``tmp_3.pdb`` into ``local_dir``.

    File ``tmp_i.pdb`` holds the pose built from the sequence ``"TEST" * i``.

    Args:
        local_dir: Directory to write the fixtures into. It is created
            (including any missing parents) if it does not exist.
    """
    # makedirs(exist_ok=True) also creates missing parent directories and
    # avoids the check-then-create race of isdir() followed by mkdir().
    os.makedirs(local_dir, exist_ok=True)
    poses = [io.pose_from_sequence("TEST" * i) for i in range(1, 4)]
    for i, pose in enumerate(poses, start=1):
        pdb_path = os.path.join(local_dir, "tmp_{0}.pdb".format(i))
        with open(pdb_path, "w", encoding="utf-8") as f:
            f.write(io.to_pdbstring(pose))
def test_score_smoke_test(self):
    """RosettaScripts 'null' tasks just score with the default score function.

    This is a high-level smoke test of the distributed namespace. A packed
    pose is created through the io layer, then passed both through the
    RosettaScripts layer (using the ``self.min_rs`` protocol) and through the
    scoring layer. Both paths must report the same total score. This covers
    pose serialization/deserialization, score value extraction and
    RosettaScripts parser access. In other words: turn on the power and look
    for magic smoke.
    """
    direct_scorer = score.ScorePoseTask()
    scripted_scorer = rosetta_scripts.SingleoutputRosettaScriptsTask(
        self.min_rs)

    direct_scores = direct_scorer(io.pose_from_sequence("TEST")).scores
    scripted_scores = scripted_scorer(io.pose_from_sequence("TEST")).scores

    self.assertAlmostEqual(direct_scores["total_score"],
                           scripted_scores["total_score"])
def run_task(seq):
    """Build a pose from ``seq`` and run a one-repeat FastRelax protocol on it.

    Args:
        seq: Sequence string passed to ``io.pose_from_sequence``.

    Returns:
        Whatever ``SingleoutputRosettaScriptsTask`` returns for the pose
        (presumably the relaxed packed pose; confirm against the task API).
    """
    # The FastRelax mover is registered under the name "score" in the XML.
    relax_xml = """
    <ROSETTASCRIPTS>
        <MOVERS>
            <FastRelax name="score" repeats="1" />
        </MOVERS>
        <PROTOCOLS>
            <Add mover_name="score"/>
        </PROTOCOLS>
    </ROSETTASCRIPTS>
    """
    start_pose = io.pose_from_sequence(seq)
    relax_task = rosetta_scripts.SingleoutputRosettaScriptsTask(relax_xml)
    return relax_task(start_pose)
def test_update_score(self):
    """PackedPose.update_scores returns an updated copy.

    ``update_scores`` performs a copy-update of the packed pose, not an
    in-place modification: the original pose's scores stay empty. New score
    values are taken from positional dicts and keyword arguments. For
    duplicate keys, later positional dicts override earlier ones, and
    keyword arguments override all positional dicts.
    """
    work_pose = io.pose_from_sequence("TEST")
    self.assertDictEqual(work_pose.scores, dict())

    # Each case: (positional dicts, keyword arguments, expected scores).
    cases = [
        # Merge and masking, positional dicts only: the later dict wins.
        (({"arg": 1.0, "dupe": "foo"}, dict(arg2=2.0, dupe="bar")),
         {},
         dict(arg=1.0, arg2=2.0, dupe="bar")),
        # Merge and masking, positional dicts plus kwargs: kwargs win.
        (({"arg": 1.0, "dupe": "foo"}, dict(arg2=2.0, dupe="bar")),
         dict(kwarg="yes", dupe="bat"),
         dict(arg=1.0, arg2=2.0, kwarg="yes", dupe="bat")),
        # Keyword arguments only.
        ((),
         dict(kwarg="yes", dupe="bat"),
         dict(kwarg="yes", dupe="bat")),
        # No arguments at all.
        ((), {}, dict()),
    ]

    for positional, keywords, expected in cases:
        updated = work_pose.update_scores(*positional, **keywords)
        # The source pose must be untouched after every update.
        self.assertDictEqual(work_pose.scores, dict())
        self.assertDictEqual(updated.scores, expected)
def test_concurrent_on_task(self):
    """A single RosettaScripts task object can be run from several threads at once.

    The same FastRelax task is mapped over three copies of one pose on a
    three-thread pool. The results are drained with ``list`` so that any
    exception raised in a worker thread is re-raised here.
    """
    import concurrent.futures

    relax_xml = """
    <ROSETTASCRIPTS>
        <MOVERS>
            <FastRelax name="score" repeats="1"/>
        </MOVERS>
        <PROTOCOLS>
            <Add mover_name="score"/>
        </PROTOCOLS>
    </ROSETTASCRIPTS>
    """
    shared_task = rosetta_scripts.SingleoutputRosettaScriptsTask(relax_xml)
    pose = io.pose_from_sequence("TEST")

    inputs = [pose] * 3
    with concurrent.futures.ThreadPoolExecutor(max_workers=3) as pool:
        outputs = list(pool.map(shared_task, inputs))
def test_worker_extra(self):
    """worker_extra worker plugin controls rosetta flags and local dir.

    pyrosetta.distributed.dask.worker_extra can be used to specify a set of
    PyRosetta initialization flags and local directory information that will
    be used to pre-initialize all worker processes in a dask cluster. This
    provides support for protocols which may require command line flags
    (e.g. additional residue types, logging flags, ...).

    The test writes an extra residue params file, scores a pose containing
    that residue locally and on a freshly launched dask worker, and checks
    that both total scores agree.
    """
    # Import locally to allow test discovery if dependencies aren't present
    import shlex

    from dask import delayed
    from dask.distributed import Client, LocalCluster
    import pyrosetta
    import pyrosetta.distributed.dask
    import pyrosetta.distributed.io as io
    import pyrosetta.distributed.tasks.score as score

    # Setup cluster scheduler and working directory
    with tempfile.TemporaryDirectory() as workdir, LocalCluster(
            n_workers=0, diagnostics_port=None) as cluster:

        # Context manager controls launch & teardown of test worker
        @contextlib.contextmanager
        def one_worker(init_flags):
            init = pyrosetta.distributed.dask.worker_extra(
                init_flags=init_flags, local_directory=workdir)
            # The worker_extra entries are joined verbatim, so the command
            # still runs through the shell. Only the values this test controls
            # (interpreter path, scheduler address) are quoted. "exec" replaces
            # the shell with the worker so terminate()/kill() signal the worker
            # itself rather than only the wrapping shell.
            worker_command = (
                "exec %s -m distributed.cli.dask_worker %s --nthreads 1 %s" % (
                    shlex.quote(sys.executable),
                    shlex.quote(cluster.scheduler_address),
                    " ".join(init),
                ))
            worker = subprocess.Popen(worker_command, shell=True)
            try:
                wait = 10.0
                while len(cluster.scheduler.workers) == 0:
                    worker.poll()
                    assert worker.returncode is None
                    time.sleep(.1)
                    wait -= .1
                    assert wait > 0, "Timeout waiting for worker launch."
                yield None
            finally:
                worker.terminate()
                try:
                    worker.wait(1)
                except subprocess.TimeoutExpired:
                    # Popen.wait raises subprocess.TimeoutExpired, not the
                    # builtin TimeoutError; escalate to kill so the worker
                    # process is not leaked.
                    worker.kill()
                    worker.wait()

        HBI_fa_params_path = os.path.join(workdir, "HBI.fa.params")
        with open(HBI_fa_params_path, "w") as of:
            of.write(_HBI_fa_params)

        # Unformatted PyRosetta command line flags. The line breaks matter:
        # each "#" comment presumably runs to the end of its own line.
        flags = """
        -extra_res_fa {0} #Test flag comment
        -ignore_unrecognized_res 1
        -use_input_sc 1
        -ex4
        ### Test flag comment
        """.format(HBI_fa_params_path)

        # Initialize PyRosetta
        pyrosetta.distributed.dask.init_notebook(flags)

        # Setup protocol
        protocol = score.ScorePoseTask("ref2015_cart")

        # Setup pose
        pose = io.pose_from_sequence("TESTING [HBI]")

        # Score locally
        local_result = protocol(pose)

        # Score on dask-worker with correct command-line flags
        with one_worker(flags), Client(cluster):
            cluster_result = delayed(protocol)(pose).compute()

        self.assertEqual(cluster_result.scores["total_score"],
                         local_result.scores["total_score"])